# -*- coding: utf-8 -*-

# Copyright 2024 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

import re
import sys
import threading
import time
from unittest import mock

from google import auth
from google.auth import credentials as auth_credentials
from google.cloud import aiplatform
import vertexai
from google.cloud.aiplatform import initializer
from google.cloud.aiplatform import utils as aiplatform_utils
from google.cloud.aiplatform.metadata import metadata
from google.cloud.aiplatform_v1.services import (
    evaluation_service as gapic_evaluation_services,
)
from google.cloud.aiplatform_v1.types import (
    evaluation_service as gapic_evaluation_service_types,
)
from google.cloud.aiplatform_v1beta1.services import (
    evaluation_service as gapic_evaluation_services_preview,
)
from google.cloud.aiplatform_v1beta1.types import content
from google.cloud.aiplatform_v1beta1.types import (
    evaluation_service as gapic_evaluation_service_types_preview,
)
from vertexai import evaluation
from vertexai import generative_models
from vertexai.evaluation import _base as eval_base
from vertexai.evaluation import _evaluation
from vertexai.evaluation import eval_task
from vertexai.evaluation import utils
from vertexai.evaluation.metrics import _rouge
from vertexai.evaluation.metrics import metric_prompt_template
from vertexai.evaluation.metrics import (
    metric_prompt_template_examples,
)
from vertexai.evaluation.metrics import pairwise_metric
from vertexai.evaluation.metrics import pointwise_metric
from vertexai.preview import evaluation as evaluation_preview
from vertexai.preview import reasoning_engines
from vertexai.preview.evaluation import (
    _pre_eval_utils as pre_eval_utils_preview,
)
from vertexai.preview.evaluation import (
    constants as constants_preview,
)
from vertexai.preview.evaluation import utils as utils_preview
from vertexai.preview.evaluation.metrics import (
    _default_templates as default_templates_preview,
)
from vertexai.preview.evaluation.metrics import (
    custom_output_config,
)
from vertexai.preview.evaluation.metrics import (
    metric_prompt_template_examples as metric_prompt_template_examples_preview,
)
from vertexai.preview.evaluation.metrics import (
    pairwise_metric as pairwise_metric_preview,
)
from vertexai.preview.evaluation.metrics import (
    pointwise_metric as pointwise_metric_preview,
)
from vertexai.preview.evaluation.metrics import (
    rubric_based_metric,
)
from google.cloud.aiplatform.utils.gcs_utils import blob_from_uri
import numpy as np
import pandas as pd
import pytest


# Short aliases for the SDK classes exercised throughout this module. The two
# private names below only exist to keep the nested-attribute chains readable.
_TemplateExamples = metric_prompt_template_examples.MetricPromptTemplateExamples
_TemplateExamplesPreview = (
    evaluation_preview.metrics.metric_prompt_template_examples.MetricPromptTemplateExamples
)

# GA evaluation surface.
EvalTask = eval_task.EvalTask
Pointwise = _TemplateExamples.Pointwise
Pairwise = _TemplateExamples.Pairwise

# Preview evaluation surface.
AutoraterConfig = evaluation_preview.AutoraterConfig
EvalTaskPreview = evaluation_preview.eval_task.EvalTask
PointwisePreview = _TemplateExamplesPreview.Pointwise
PairwisePreview = _TemplateExamplesPreview.Pairwise
RubricGenerationConfig = evaluation_preview.RubricGenerationConfig

# Preview GAPIC / content types.
ContentMap = gapic_evaluation_service_types_preview.ContentMap
Content = content.Content
Part = content.Part


_TEST_PROJECT = "test-project"
_TEST_LOCATION = "us-central1"
_TEST_BUCKET = "gs://test-bucket"
_TEST_FILE_NAME = "test-file-name.csv"
_AUTORATER_INSTRUCTION = """
You are an expert evaluator. Your task is to evaluate the quality of the responses generated by AI models.
"""
_METRIC_DEFINITION = "You will be assessing Text Quality"
_CRITERIA = {
    "Coherence": ("The response presents ideas in a logical and organized manner."),
    "Fluency": "The text flows smoothly and naturally",
}
_POINTWISE_RATING_RUBRIC = {
    "3": "(Good). Well-written.",
    "2": "(Ok). Adequate writing with decent coherence and fluency.",
    "1": "(Bad). Poorly written.",
}
_PAIRWISE_RATING_RUBRIC = {
    "A": "Response A answers better than response B.",
    "SAME": "Response A and B answers equally well.",
    "B": "Response B answers better than response A.",
}
_EVALUATION_STEPS = {
    "STEP 1": "Assess grammar correctness",
    "STEP 2": "Assess word choice and flow",
}
# Custom model-based metrics built from structured prompt templates. Both use
# the same definition, criteria and steps; only the rating rubric differs.
_TEST_POINTWISE_METRIC = pointwise_metric.PointwiseMetric(
    metric="test_pointwise_metric",
    metric_prompt_template=metric_prompt_template.PointwiseMetricPromptTemplate(
        criteria=_CRITERIA,
        rating_rubric=_POINTWISE_RATING_RUBRIC,
        metric_definition=_METRIC_DEFINITION,
        evaluation_steps=_EVALUATION_STEPS,
    ),
)
# Pointwise metric whose template is a free-form string with one placeholder.
_TEST_POINTWISE_METRIC_FREE_STRING = pointwise_metric.PointwiseMetric(
    metric_prompt_template="abc: {abc}",
    metric="test_pointwise_metric_str",
)
_TEST_PAIRWISE_METRIC = pairwise_metric.PairwiseMetric(
    metric="test_pairwise_metric",
    metric_prompt_template=metric_prompt_template.PairwiseMetricPromptTemplate(
        criteria=_CRITERIA,
        rating_rubric=_PAIRWISE_RATING_RUBRIC,
        metric_definition=_METRIC_DEFINITION,
        evaluation_steps=_EVALUATION_STEPS,
    ),
)

# Translation metrics, English source to Chinese target.
_TEST_COMET = pointwise_metric.Comet(
    source_language="en",
    target_language="zh",
    version="COMET_22_SRC_REF",
)
_TEST_METRICX = pointwise_metric.MetricX(
    source_language="en",
    target_language="zh",
    version="METRICX_24_SRC",
)
# Every kind of metric an EvalTask accepts, in a fixed order: metric names,
# prebuilt pointwise templates, then the custom metrics defined above.
_TEST_METRICS = (
    # Computation-based metrics referenced by name.
    *("exact_match", "bleu", "rouge_1", "rouge_2", "rouge_l", "rouge_l_sum"),
    # Prebuilt pointwise prompt templates.
    *(
        Pointwise.COHERENCE,
        Pointwise.FLUENCY,
        Pointwise.SAFETY,
        Pointwise.GROUNDEDNESS,
        Pointwise.SUMMARIZATION_QUALITY,
        Pointwise.VERBOSITY,
        Pointwise.QUESTION_ANSWERING_QUALITY,
    ),
    # Custom model-based metrics.
    _TEST_POINTWISE_METRIC,
    _TEST_PAIRWISE_METRIC,
)
# Autorater settings: 6 samples per instance, with flipping enabled.
_AUTORATER_CONFIG = AutoraterConfig(
    flip_enabled=True,
    sampling_count=6,
    autorater_model="test_autorater_model",
)
_TEST_EVAL_DATASET_WITHOUT_PROMPT = pd.DataFrame(
    {
        "response": ["test", "text"],
        "reference": ["test", "ref"],
        "context": ["test", "context"],
        "instruction": ["test", "instruction"],
    }
)
_TEST_EVAL_DATASET_WITHOUT_RESPONSE = pd.DataFrame(
    {
        "prompt": ["test", "prompt"],
        "reference": ["test", "ref"],
        "context": ["test", "context"],
        "instruction": ["test", "instruction"],
    }
)
_TEST_AGENT_EVAL_DATASET_WITHOUT_RESPONSE = pd.DataFrame(
    {
        "prompt": ["test_input1", "test_input2"],
        "reference_trajectory": [
            [{"tool_name": "test_tool1"}, {"tool_name": "test_tool2"}],
            [{"tool_name": "test_tool3"}, {"tool_name": "test_tool4"}],
        ],
    },
)
_TEST_EVAL_DATASET_ALL_INCLUDED = pd.DataFrame(
    {
        "prompt": ["test_prompt", "text_prompt"],
        "response": ["test", "text"],
        "reference": ["test", "ref"],
        "context": ["test", "context"],
        "instruction": ["test", "instruction"],
        "source": ["test", "source"],
    }
)
_TEST_EVAL_DATASET_ALL_INCLUDED_DEFAULT_FIELDS = pd.DataFrame(
    {
        "prompt": ["test_prompt", "text_prompt"],
        "response": ["test", "text"],
        "baseline_model_response": ["test", "ref"],
        "context": ["test", "context"],
        "instruction": ["test", "instruction"],
    }
)
_TEST_EVAL_DATASET_PROMPT_RESPONSE = pd.DataFrame(
    {
        "prompt": ["test_prompt", "text_prompt", "test_prompt_3"],
        "response": ["test", "text", "test_response_3"],
    }
)
_TEST_EVAL_DATASET_SINGLE = pd.DataFrame({"prompt": ["test_prompt", "text_prompt"]})
_TEST_JSONL_FILE_CONTENT = """{"prompt": "prompt", "reference": "reference"}\n
{"prompt":"test", "reference": "test"}\n
"""
_TEST_CSV_FILE_CONTENT = """reference,context,instruction\ntest,test,test\n
text,text,text\n
"""
_TEST_EXPERIMENT = "test-experiment"
_TEST_CSV = pd.DataFrame(
    columns={
        "response": ["text"],
        "reference": ["ref"],
    }
)
# Preview-API counterparts of the custom metrics above, configured with
# ``return_raw_output=True`` (presumably so the autorater's raw outputs are
# surfaced in the results).
_TEST_POINTWISE_METRIC_WITH_RAW_OUTPUT = pointwise_metric_preview.PointwiseMetric(
    metric="test_pointwise_metric",
    custom_output_config=custom_output_config.CustomOutputConfig(
        return_raw_output=True
    ),
    metric_prompt_template=metric_prompt_template.PointwiseMetricPromptTemplate(
        criteria=_CRITERIA,
        rating_rubric=_POINTWISE_RATING_RUBRIC,
        metric_definition=_METRIC_DEFINITION,
        evaluation_steps=_EVALUATION_STEPS,
    ),
)
_TEST_PAIRWISE_METRIC_WITH_RAW_OUTPUT = pairwise_metric_preview.PairwiseMetric(
    metric="test_pairwise_metric",
    custom_output_config=custom_output_config.CustomOutputConfig(
        return_raw_output=True
    ),
    metric_prompt_template=metric_prompt_template.PairwiseMetricPromptTemplate(
        criteria=_CRITERIA,
        rating_rubric=_PAIRWISE_RATING_RUBRIC,
        metric_definition=_METRIC_DEFINITION,
        evaluation_steps=_EVALUATION_STEPS,
    ),
)
_TEST_MULTIMODAL_MODEL_DATASET = pd.DataFrame(
    {
        "prompt": ["test_prompt", "text_prompt"],
        "response": [
            (
                '{"contents": [{"parts": [{"file_data": {"mime_type": "image/png",'
                ' "file_uri": "gs://test-bucket/image3.png"}}]}]}'
            ),
            (
                '{"contents": [{"parts": [{"file_data": {"mime_type": "image/png",'
                ' "file_uri": "gs://test-bucket/image4.png"}}]}]}'
            ),
        ],
    }
)
# Expected rendering of a pointwise prompt template built from an explicit
# instruction, metric definition, criteria, rubric, steps, examples and a single
# "country" input variable. Every character (including the runs of blank lines)
# is significant: these strings are presumably compared verbatim against the
# SDK-rendered template, so edit them only together with the SDK.
_EXPECTED_POINTWISE_PROMPT_TEMPLATE = """
# Instruction
hello


# Evaluation
## Metric Definition
this is eval metric

## Criteria
metric1: summarization

## Rating Rubric
0: bad
1: good

## Evaluation Steps
step1: start
step2: finish

## Evaluation Examples
Q: hi A: hello


# User Inputs and AI-generated Response
## User Inputs
### country
{country}




## AI-generated Response
{response}
"""
# Expected pointwise rendering when the instruction and evaluation steps are
# left to their defaults. The criteria match _CRITERIA. The rubric entries
# appear in ascending key order (1, 2, 3), although _POINTWISE_RATING_RUBRIC
# declares them as 3, 2, 1.
# NOTE(review): the instruction text ("an AI-generated responses") presumably
# mirrors the SDK default template word-for-word; do not fix it here alone.
_EXPECTED_POINTWISE_PROMPT_TEMPLATE_WITH_DEFAULT_VALUES = """
# Instruction
You are an expert evaluator. Your task is to evaluate the quality of the responses generated by AI models. We will provide you with the user prompt and an AI-generated responses.
You should first read the user input carefully for analyzing the task, and then evaluate the quality of the responses based on the Criteria provided in the Evaluation section below.
You will assign the response a rating following the Rating Rubric and Evaluation Steps. Give step by step explanations for your rating, and only choose ratings from the Rating Rubric.


# Evaluation
## Criteria
Coherence: The response presents ideas in a logical and organized manner.
Fluency: The text flows smoothly and naturally

## Rating Rubric
1: (Bad). Poorly written.
2: (Ok). Adequate writing with decent coherence and fluency.
3: (Good). Well-written.

## Evaluation Steps
Step 1: Assess the response in aspects of all criteria provided. Provide assessment according to each criterion.
Step 2: Score based on the rating rubric. Give a brief rationale to explain your evaluation considering each individual criterion.


# User Inputs and AI-generated Response
## User Inputs



## AI-generated Response
{response}
"""

# Pairwise counterpart of _EXPECTED_POINTWISE_PROMPT_TEMPLATE. Response A is
# filled from {baseline_model_response} and Response B from {response}.
_EXPECTED_PAIRWISE_PROMPT_TEMPLATE = """
# Instruction
hello


# Evaluation
## Metric Definition
this is eval metric

## Criteria
metric1: summarization

## Rating Rubric
A: good
B: good

## Evaluation Steps
step1: start
step2: finish

## Evaluation Examples
Q: hi A: hello


# User Inputs and AI-generated Responses
## User Inputs
### country
{country}


## AI-generated Responses
### Response A
{baseline_model_response}

### Response B
{response}
"""

# Expected pairwise rendering with default instruction and evaluation steps.
# The rubric entries appear in sorted order (A, B, SAME), although
# _PAIRWISE_RATING_RUBRIC declares them as A, SAME, B.
# NOTE(review): the instruction contains a duplicated "based on based on",
# presumably copied verbatim from the SDK default template; fix it in the SDK
# first, otherwise this expectation will stop matching.
_EXPECTED_PAIRWISE_PROMPT_TEMPLATE_WITH_DEFAULT_VALUES = """
# Instruction
You are an expert evaluator. Your task is to evaluate the quality of the responses generated by two AI models. We will provide you with the user input and a pair of AI-generated responses (Response A and Response B).
You should first read the user input carefully for analyzing the task, and then evaluate the quality of the responses based on based on the Criteria provided in the Evaluation section below.
You will first judge responses individually, following the Rating Rubric and Evaluation Steps. Then you will give step by step explanations for your judgement, compare results to declare the winner based on the Rating Rubric and Evaluation Steps.


# Evaluation
## Criteria
Coherence: The response presents ideas in a logical and organized manner.
Fluency: The text flows smoothly and naturally

## Rating Rubric
A: Response A answers better than response B.
B: Response B answers better than response A.
SAME: Response A and B answers equally well.

## Evaluation Steps
Step 1: Analyze Response A based on all the Criteria.
Step 2: Analyze Response B based on all the Criteria.
Step 3: Compare the overall performance of Response A and Response B based on your analyses and assessment.
Step 4: Output your preference of "A", "SAME" or "B" to the pairwise_choice field according to the Rating Rubrics.
Step 5: Output your assessment reasoning in the explanation field


# User Inputs and AI-generated Responses
## User Inputs

## AI-generated Responses
### Response A
{baseline_model_response}

### Response B
{response}
"""

_EXPECTED_EVAL_DATASET_PROMPT_RESPONSE_WITH_RUBRICS = pd.DataFrame(
    {
        "prompt": ["test_prompt", "text_prompt", "test_prompt_3"],
        "response": ["test", "text", "test_response_3"],
        "rubrics": [
            ["test_rubric1", "test_rubric2"],
            ["test_rubric1", "test_rubric2"],
            ["test_rubric1", "test_rubric2"],
        ],
    }
)

_MOCK_RUNNABLE_INFERENCE_RESPONSE = [
    {
        "input": "test_input",
        "output": "test_output",
        "intermediate_steps": [
            [{"kwargs": {"tool": "test_tool1"}, "tool_output": "test_tool_output"}],
            [{"kwargs": {"tool": "test_tool2"}, "tool_output": "test_tool_output"}],
        ],
    },
    {
        "input": "test_input",
        "output": "test_output",
        "intermediate_steps": [
            [{"kwargs": {"tool": "test_tool2"}, "tool_output": "test_tool_output"}],
            [{"kwargs": {"tool": "test_tool3"}, "tool_output": "test_tool_output"}],
        ],
    },
]

# Three canned rubric-based instruction-following responses, one per mocked call:
#   * score 0.5: rubric_1 has verdict=True, rubric_2 has no verdict set;
#   * score 0.0: neither rubric_1 nor rubric_2 has a verdict set;
#   * score 1.0: a single rubric_1 with verdict=True.
# Missing verdicts are omitted rather than passed as False.
# NOTE(review): if ``verdict`` is a presence-tracked (optional) proto field,
# "unset" and "False" are not equivalent, so confirm before simplifying this literal.
_MOCK_RUBRIC_BASED_INSTRUCTION_FOLLOWING_RESULT = (
    gapic_evaluation_service_types_preview.EvaluateInstancesResponse(
        rubric_based_instruction_following_result=(
            gapic_evaluation_service_types_preview.RubricBasedInstructionFollowingResult(
                score=0.5,
                rubric_critique_results=[
                    gapic_evaluation_service_types_preview.RubricCritiqueResult(
                        rubric="rubric_1",
                        verdict=True,
                    ),
                    gapic_evaluation_service_types_preview.RubricCritiqueResult(
                        rubric="rubric_2",
                    ),
                ],
            )
        )
    ),
    gapic_evaluation_service_types_preview.EvaluateInstancesResponse(
        rubric_based_instruction_following_result=(
            gapic_evaluation_service_types_preview.RubricBasedInstructionFollowingResult(
                score=0.0,
                rubric_critique_results=[
                    gapic_evaluation_service_types_preview.RubricCritiqueResult(
                        rubric="rubric_1",
                    ),
                    gapic_evaluation_service_types_preview.RubricCritiqueResult(
                        rubric="rubric_2",
                    ),
                ],
            )
        )
    ),
    gapic_evaluation_service_types_preview.EvaluateInstancesResponse(
        rubric_based_instruction_following_result=(
            gapic_evaluation_service_types_preview.RubricBasedInstructionFollowingResult(
                score=1.0,
                rubric_critique_results=[
                    gapic_evaluation_service_types_preview.RubricCritiqueResult(
                        rubric="rubric_1",
                        verdict=True,
                    ),
                ],
            )
        )
    ),
)

# Canned ``EvaluateInstancesResponse`` sequences, one response per mocked
# service call (e.g. fed to ``evaluate_instances`` through ``side_effect``).
# Each tuple is generated so the varying value (score or choice) stands out.

# Exact match: the first row matches (1.0), the second does not (0.0).
_MOCK_EXACT_MATCH_RESULT = tuple(
    gapic_evaluation_service_types.EvaluateInstancesResponse(
        exact_match_results=gapic_evaluation_service_types.ExactMatchResults(
            exact_match_metric_values=[
                gapic_evaluation_service_types.ExactMatchMetricValue(score=score),
            ]
        )
    )
    for score in (1.0, 0.0)
)
# Trajectory exact match (preview API): 1.0, then 0.0.
_MOCK_TRAJECTORY_EXACT_MATCH_RESULT = tuple(
    gapic_evaluation_service_types_preview.EvaluateInstancesResponse(
        trajectory_exact_match_results=(
            gapic_evaluation_service_types_preview.TrajectoryExactMatchResults(
                trajectory_exact_match_metric_values=[
                    gapic_evaluation_service_types_preview.TrajectoryExactMatchMetricValue(
                        score=score
                    ),
                ]
            )
        )
    )
    for score in (1.0, 0.0)
)
# Generic pointwise model-based result: scores 5, then 4.
_MOCK_POINTWISE_RESULT = tuple(
    gapic_evaluation_service_types.EvaluateInstancesResponse(
        pointwise_metric_result=gapic_evaluation_service_types.PointwiseMetricResult(
            score=score, explanation="explanation"
        )
    )
    for score in (5, 4)
)
# Generic pairwise model-based result: the baseline wins both rows.
_MOCK_PAIRWISE_RESULT = tuple(
    gapic_evaluation_service_types.EvaluateInstancesResponse(
        pairwise_metric_result=gapic_evaluation_service_types.PairwiseMetricResult(
            pairwise_choice=choice,
            explanation="explanation",
        )
    )
    for choice in (
        gapic_evaluation_service_types.PairwiseChoice.BASELINE,
        gapic_evaluation_service_types.PairwiseChoice.BASELINE,
    )
)
# Summarization quality (GA API): scores 5, then 4.
_MOCK_SUMMARIZATION_QUALITY_RESULT = tuple(
    gapic_evaluation_service_types.EvaluateInstancesResponse(
        pointwise_metric_result=gapic_evaluation_service_types.PointwiseMetricResult(
            score=score, explanation="explanation"
        )
    )
    for score in (5, 4)
)
# Summarization quality (preview API): scores 5, then 4.
_MOCK_SUMMARIZATION_QUALITY_RESULT_PREVIEW = tuple(
    gapic_evaluation_service_types_preview.EvaluateInstancesResponse(
        pointwise_metric_result=gapic_evaluation_service_types_preview.PointwiseMetricResult(
            score=score, explanation="explanation"
        )
    )
    for score in (5, 4)
)
# Coherence (preview API): scores 5, then 4.
_MOCK_COHERENCE_RESULT = tuple(
    gapic_evaluation_service_types_preview.EvaluateInstancesResponse(
        pointwise_metric_result=gapic_evaluation_service_types_preview.PointwiseMetricResult(
            score=score, explanation="explanation"
        )
    )
    for score in (5, 4)
)
# Pairwise summarization quality: the baseline wins row 0, the candidate row 1.
_MOCK_PAIRWISE_SUMMARIZATION_QUALITY_RESULT = tuple(
    gapic_evaluation_service_types.EvaluateInstancesResponse(
        pairwise_metric_result=gapic_evaluation_service_types.PairwiseMetricResult(
            pairwise_choice=choice,
            explanation="explanation",
        )
    )
    for choice in (
        gapic_evaluation_service_types.PairwiseChoice.BASELINE,
        gapic_evaluation_service_types.PairwiseChoice.CANDIDATE,
    )
)
# Single-candidate model reply whose only text part is "test_response".
_MOCK_MODEL_INFERENCE_RESPONSE = generative_models.GenerationResponse.from_dict(
    {"candidates": [{"content": {"parts": [{"text": "test_response"}]}}]}
)
# Minimal one-row EvalResult for a single "mock_metric". The std is NaN because
# there is only one row.
MOCK_EVAL_RESULT = eval_base.EvalResult(
    metrics_table=pd.DataFrame({"response": ["test"], "mock_metric": [1.0]}),
    summary_metrics={
        "row_count": 1,
        "mock_metric/mean": 1.0,
        "mock_metric/std": np.nan,
    },
)
# Preview responses that carry only the autorater's raw outputs. Row ``r``
# returns the two samples "raw_output_sample_r.1" and "raw_output_sample_r.2".
_MOCK_POINTWISE_RESULT_WITH_RAW_OUTPUT = tuple(
    gapic_evaluation_service_types_preview.EvaluateInstancesResponse(
        pointwise_metric_result=gapic_evaluation_service_types_preview.PointwiseMetricResult(
            custom_output=gapic_evaluation_service_types_preview.CustomOutput(
                raw_outputs=gapic_evaluation_service_types_preview.RawOutput(
                    raw_output=[
                        f"raw_output_sample_{row}.{sample}" for sample in (1, 2)
                    ],
                ),
            )
        )
    )
    for row in (1, 2)
)
_MOCK_PAIRWISE_RESULT_WITH_RAW_OUTPUT = tuple(
    gapic_evaluation_service_types_preview.EvaluateInstancesResponse(
        pairwise_metric_result=gapic_evaluation_service_types_preview.PairwiseMetricResult(
            custom_output=gapic_evaluation_service_types_preview.CustomOutput(
                raw_outputs=gapic_evaluation_service_types_preview.RawOutput(
                    raw_output=[
                        f"raw_output_sample_{row}.{sample}" for sample in (1, 2)
                    ],
                ),
            )
        )
    )
    for row in (1, 2)
)
# Rubric-generation model replies. The text is a ```json fenced object with a
# "questions" list. The second variant also carries an extra "desc" key.
_MOCK_MODEL_RUBRIC_GENERATION_RESPONSE = generative_models.GenerationResponse.from_dict(
    {
        "candidates": [
            {
                "content": {
                    "parts": [
                        {
                            "text": '```json{"questions": ["test_rubric1", "test_rubric2"]}```'
                        }
                    ]
                }
            }
        ]
    }
)
_MOCK_MODEL_RUBRIC_GENERATION_RESPONSE_WITH_ADDITIONAL = generative_models.GenerationResponse.from_dict(
    {
        "candidates": [
            {
                "content": {
                    "parts": [
                        {
                            "text": '```json{"questions": ["test_rubric1", "test_rubric2"], "desc": "test_desc"}```'
                        }
                    ]
                }
            }
        ]
    }
)
_UNPARSED_RUBRIC = """```json{"questions": ["test_rubric"]}```"""
_INVALID_UNPARSED_RUBRIC = """```json{["questions": ["test_rubric"]]}```"""
# Expected ROUGE-Lsum requests: one per dataset row, each comparing the mocked
# model output "test_response" against that row's reference.
_EXPECTED_ROUGE_REQUESTS = tuple(
    gapic_evaluation_service_types.EvaluateInstancesRequest(
        location=f"projects/{_TEST_PROJECT}/locations/{_TEST_LOCATION}",
        rouge_input=gapic_evaluation_service_types.RougeInput(
            metric_spec=gapic_evaluation_service_types.RougeSpec(
                rouge_type="rougeLsum", use_stemmer=True, split_summaries=True
            ),
            instances=[
                gapic_evaluation_service_types.RougeInstance(
                    prediction="test_response", reference=reference
                ),
            ],
        ),
    )
    for reference in ("test", "ref")
)
# Matching canned ROUGE scores: 1.0 for the first request, 0.5 for the second.
_MOCK_ROUGE_RESULT = tuple(
    gapic_evaluation_service_types.EvaluateInstancesResponse(
        rouge_results=gapic_evaluation_service_types.RougeResults(
            rouge_metric_values=[
                gapic_evaluation_service_types.RougeMetricValue(score=score)
            ]
        )
    )
    for score in (1.0, 0.5)
)
_EXPECTED_COLUMN_MAPPING = {
    "context": "context",
    "reference": "reference",
    "response": "response",
    "instruction": "instruction",
    "prompt": "prompt",
    "source": "source",
}
# The order of the responses is important. The sequence alternates a COMET
# result and a MetricX result for each (presumably per-row) pair:
# COMET 0.1, MetricX 5, COMET 0.9, MetricX 20.
_MOCK_MODEL_BASED_TRANSLATION_RESULT = tuple(
    response
    for comet_score, metricx_score in ((0.1, 5), (0.9, 20))
    for response in (
        gapic_evaluation_service_types.EvaluateInstancesResponse(
            comet_result=gapic_evaluation_service_types.CometResult(score=comet_score)
        ),
        gapic_evaluation_service_types.EvaluateInstancesResponse(
            metricx_result=gapic_evaluation_service_types.MetricxResult(
                score=metricx_score
            )
        ),
    )
)


@pytest.fixture(scope="module")
def google_auth_mock():
    """Patch ``google.auth.default`` for the whole module.

    While active, ``auth.default()`` returns anonymous credentials paired with
    the test project id. Yields the patched mock.
    """
    anonymous_default = (auth_credentials.AnonymousCredentials(), _TEST_PROJECT)
    with mock.patch.object(
        auth, "default", return_value=anonymous_default
    ) as patched_default:
        yield patched_default


@pytest.fixture
def mock_experiment_tracker():
    """Replace ``metadata._experiment_tracker`` with an autospecced mock and yield it."""
    patcher = mock.patch.object(metadata, "_experiment_tracker", autospec=True)
    with patcher as tracker:
        yield tracker


@pytest.fixture
def mock_storage_blob_from_string():
    """Patch the ``Blob`` factory that ``blob_from_uri`` relies on.

    Patches ``storage.Blob.from_uri`` (as seen from ``blob_from_uri``'s module
    globals) when that attribute exists, otherwise ``Blob.from_string``. This is
    presumably to cover different google-cloud-storage versions. Yields the
    patched mock.
    """
    blob_cls = blob_from_uri.__globals__["storage"].Blob
    factory_name = "from_uri" if hasattr(blob_cls, "from_uri") else "from_string"
    with mock.patch.object(blob_cls, factory_name) as patched_factory:
        yield patched_factory


@pytest.mark.usefixtures("google_auth_mock")
class TestEvaluation:
    """Unit tests for ``EvalTask`` / ``EvalTaskPreview`` evaluation flows.

    ``google.auth.default`` is patched for every test via ``google_auth_mock``.
    Most tests replace ``EvaluationServiceClient.evaluate_instances`` with a
    mock whose ``side_effect`` is a tuple of canned responses. Tests that need
    model inference use autospecced ``GenerativeModel`` mocks.
    """

    def setup_method(self):
        """Initialize the SDK with the test project and location."""
        vertexai.init(
            project=_TEST_PROJECT,
            location=_TEST_LOCATION,
        )

    def teardown_method(self):
        """Shut down the SDK's global pool, waiting for it to finish."""
        initializer.global_pool.shutdown(wait=True)

    def test_create_eval_task(self):
        """EvalTask stores its inputs and builds the default column mapping."""
        test_eval_task = EvalTask(
            dataset=_TEST_EVAL_DATASET_ALL_INCLUDED,
            metrics=_TEST_METRICS,
            experiment=_TEST_EXPERIMENT,
        )

        assert test_eval_task.dataset.equals(_TEST_EVAL_DATASET_ALL_INCLUDED)
        assert test_eval_task.metrics == _TEST_METRICS
        assert test_eval_task.experiment == _TEST_EXPERIMENT
        assert test_eval_task._metric_column_mapping == _EXPECTED_COLUMN_MAPPING

    @pytest.mark.parametrize("api_transport", ["grpc", "rest"])
    def test_compute_exact_match_metric(self, api_transport):
        """exact_match yields per-row scores plus mean/std summary metrics."""
        aiplatform.init(
            project=_TEST_PROJECT,
            location=_TEST_LOCATION,
            api_transport=api_transport,
        )
        eval_dataset = pd.DataFrame(
            {
                "response": ["test", "text"],
                "reference": ["test", "ref"],
            }
        )
        test_metrics = ["exact_match"]
        test_eval_task = EvalTask(dataset=eval_dataset, metrics=test_metrics)
        mock_metric_results = _MOCK_EXACT_MATCH_RESULT
        with mock.patch.object(
            target=gapic_evaluation_services.EvaluationServiceClient,
            attribute="evaluate_instances",
            side_effect=mock_metric_results,
        ):
            test_result = test_eval_task.evaluate()

        assert test_result.summary_metrics["row_count"] == 2
        assert test_result.summary_metrics["exact_match/mean"] == 0.5
        assert test_result.summary_metrics["exact_match/std"] == pytest.approx(0.7, 0.1)
        assert list(test_result.metrics_table.columns.values) == [
            "response",
            "reference",
            "exact_match/score",
        ]
        assert test_result.metrics_table[["response", "reference"]].equals(eval_dataset)
        assert list(test_result.metrics_table["exact_match/score"].values) == [
            1.0,
            0.0,
        ]

    @pytest.mark.parametrize("api_transport", ["grpc", "rest"])
    def test_compute_pointwise_metrics(self, api_transport):
        """A custom pointwise metric produces score and explanation columns."""
        aiplatform.init(
            project=_TEST_PROJECT,
            location=_TEST_LOCATION,
            api_transport=api_transport,
        )
        test_metrics = [_TEST_POINTWISE_METRIC]
        test_eval_task = EvalTask(
            dataset=_TEST_EVAL_DATASET_ALL_INCLUDED, metrics=test_metrics
        )
        mock_metric_results = _MOCK_POINTWISE_RESULT
        with mock.patch.object(
            target=gapic_evaluation_services.EvaluationServiceClient,
            attribute="evaluate_instances",
            side_effect=mock_metric_results,
        ):
            test_result = test_eval_task.evaluate()

        assert test_result.summary_metrics["row_count"] == 2
        assert test_result.summary_metrics["test_pointwise_metric/mean"] == 4.5
        assert test_result.summary_metrics[
            "test_pointwise_metric/std"
        ] == pytest.approx(0.7, 0.1)
        assert set(test_result.metrics_table.columns.values) == set(
            [
                "prompt",
                "response",
                "context",
                "instruction",
                "reference",
                "test_pointwise_metric/score",
                "test_pointwise_metric/explanation",
                "source",
            ]
        )
        assert test_result.metrics_table["response"].equals(
            _TEST_EVAL_DATASET_ALL_INCLUDED["response"]
        )
        assert test_result.metrics_table["prompt"].equals(
            _TEST_EVAL_DATASET_ALL_INCLUDED["prompt"]
        )
        scores = list(test_result.metrics_table["test_pointwise_metric/score"].values)
        # Either ordering of the two mocked scores is accepted here.
        assert scores == [5, 4] or scores == [4, 5]
        assert list(
            test_result.metrics_table["test_pointwise_metric/explanation"].values
        ) == [
            "explanation",
            "explanation",
        ]

    def test_compute_pointwise_metrics_free_string(self):
        """A free-string pointwise metric with a custom metric_column_mapping."""
        test_eval_task = EvalTask(
            dataset=_TEST_EVAL_DATASET_ALL_INCLUDED,
            metrics=[_TEST_POINTWISE_METRIC_FREE_STRING],
            metric_column_mapping={"abc": "prompt"},
        )
        mock_metric_results = _MOCK_POINTWISE_RESULT
        with mock.patch.object(
            target=gapic_evaluation_services.EvaluationServiceClient,
            attribute="evaluate_instances",
            side_effect=mock_metric_results,
        ):
            test_result = test_eval_task.evaluate()

        assert test_result.summary_metrics["row_count"] == 2
        assert test_result.summary_metrics["test_pointwise_metric_str/mean"] == 4.5
        assert test_result.summary_metrics[
            "test_pointwise_metric_str/std"
        ] == pytest.approx(0.7, 0.1)
        assert set(test_result.metrics_table.columns.values) == set(
            [
                "prompt",
                "response",
                "context",
                "instruction",
                "reference",
                "test_pointwise_metric_str/score",
                "test_pointwise_metric_str/explanation",
                "source",
            ]
        )
        assert test_result.metrics_table["response"].equals(
            _TEST_EVAL_DATASET_ALL_INCLUDED["response"]
        )
        assert test_result.metrics_table["prompt"].equals(
            _TEST_EVAL_DATASET_ALL_INCLUDED["prompt"]
        )
        assert list(
            test_result.metrics_table["test_pointwise_metric_str/score"].values
        ) == [5, 4]
        assert list(
            test_result.metrics_table["test_pointwise_metric_str/explanation"].values
        ) == [
            "explanation",
            "explanation",
        ]

    @pytest.mark.parametrize("api_transport", ["grpc", "rest"])
    def test_compute_pointwise_metrics_metric_prompt_template_example(
        self, api_transport
    ):
        """Pointwise.SUMMARIZATION_QUALITY with mocked model inference."""
        aiplatform.init(
            project=_TEST_PROJECT,
            location=_TEST_LOCATION,
            api_transport=api_transport,
        )
        mock_model = mock.create_autospec(
            generative_models.GenerativeModel, instance=True
        )
        mock_model.generate_content.return_value = _MOCK_MODEL_INFERENCE_RESPONSE
        mock_model._model_name = "publishers/google/model/gemini-1.0-pro"
        test_metrics = [Pointwise.SUMMARIZATION_QUALITY]
        test_eval_task = EvalTask(
            dataset=_TEST_EVAL_DATASET_WITHOUT_RESPONSE, metrics=test_metrics
        )
        mock_metric_results = _MOCK_SUMMARIZATION_QUALITY_RESULT
        with mock.patch.object(
            target=gapic_evaluation_services.EvaluationServiceClient,
            attribute="evaluate_instances",
            side_effect=mock_metric_results,
        ):
            test_result = test_eval_task.evaluate(
                model=mock_model,
                prompt_template="{instruction} test prompt template {context}",
            )

        assert test_result.summary_metrics["row_count"] == 2
        assert test_result.summary_metrics["summarization_quality/mean"] == 4.5
        assert test_result.summary_metrics[
            "summarization_quality/std"
        ] == pytest.approx(0.7, 0.1)
        assert set(test_result.metrics_table.columns.values) == set(
            [
                "context",
                "instruction",
                "reference",
                "prompt",
                "response",
                "summarization_quality/score",
                "summarization_quality/explanation",
            ]
        )
        assert list(
            test_result.metrics_table["summarization_quality/score"].values
        ) == [5, 4]
        assert list(
            test_result.metrics_table["summarization_quality/explanation"].values
        ) == [
            "explanation",
            "explanation",
        ]

    @pytest.mark.parametrize("api_transport", ["grpc", "rest"])
    def test_compute_pointwise_metrics_without_model_inference(self, api_transport):
        """Pointwise.SUMMARIZATION_QUALITY on a dataset that has responses."""
        aiplatform.init(
            project=_TEST_PROJECT,
            location=_TEST_LOCATION,
            api_transport=api_transport,
        )
        test_metrics = [Pointwise.SUMMARIZATION_QUALITY]
        test_eval_task = EvalTask(
            dataset=_TEST_EVAL_DATASET_ALL_INCLUDED, metrics=test_metrics
        )
        mock_metric_results = _MOCK_SUMMARIZATION_QUALITY_RESULT
        with mock.patch.object(
            target=gapic_evaluation_services.EvaluationServiceClient,
            attribute="evaluate_instances",
            side_effect=mock_metric_results,
        ):
            test_result = test_eval_task.evaluate()

        assert test_result.summary_metrics["row_count"] == 2
        assert test_result.summary_metrics["summarization_quality/mean"] == 4.5
        assert test_result.summary_metrics[
            "summarization_quality/std"
        ] == pytest.approx(0.7, 0.1)
        assert set(test_result.metrics_table.columns.values) == set(
            [
                "context",
                "instruction",
                "reference",
                "prompt",
                "response",
                "summarization_quality/score",
                "summarization_quality/explanation",
                "source",
            ]
        )
        assert list(
            test_result.metrics_table["summarization_quality/score"].values
        ) == [5, 4]
        assert list(
            test_result.metrics_table["summarization_quality/explanation"].values
        ) == [
            "explanation",
            "explanation",
        ]

    @pytest.mark.skipif(
        sys.version_info >= (3, 13), reason="flaky race condition in python 3.13"
    )
    @pytest.mark.parametrize("api_transport", ["grpc", "rest"])
    def test_compute_model_based_translation_metrics_without_model_inference(
        self, api_transport
    ):
        """COMET and MetricX scores are computed and aggregated per metric."""
        aiplatform.init(
            project=_TEST_PROJECT,
            location=_TEST_LOCATION,
            api_transport=api_transport,
        )
        test_metrics = [_TEST_COMET, _TEST_METRICX]
        test_eval_task = EvalTask(
            dataset=_TEST_EVAL_DATASET_ALL_INCLUDED, metrics=test_metrics
        )

        mock_metric_results = _MOCK_MODEL_BASED_TRANSLATION_RESULT
        with mock.patch.object(
            target=gapic_evaluation_services.EvaluationServiceClient,
            attribute="evaluate_instances",
            side_effect=mock_metric_results,
        ):
            test_result = test_eval_task.evaluate()

        assert test_result.summary_metrics["row_count"] == 2
        assert test_result.summary_metrics["comet/mean"] == 0.5
        assert test_result.summary_metrics["metricx/mean"] == 12.5
        # Sample standard deviations (ddof=1, as implied by exact_match/std of
        # about 0.707 for [1, 0]) of the scores [0.1, 0.9] and [5, 20]. The
        # previous tolerances (rel=0.6 and rel=11) accepted almost any value.
        assert test_result.summary_metrics["comet/std"] == pytest.approx(
            0.5657, rel=0.01
        )
        assert test_result.summary_metrics["metricx/std"] == pytest.approx(
            10.607, rel=0.01
        )
        assert set(test_result.metrics_table.columns.values) == set(
            [
                "context",
                "instruction",
                "reference",
                "prompt",
                "response",
                "source",
                "comet/score",
                "metricx/score",
            ]
        )
        assert list(test_result.metrics_table["comet/score"].values) == [0.1, 0.9]
        assert list(test_result.metrics_table["metricx/score"].values) == [5, 20]

    @pytest.mark.parametrize("api_transport", ["grpc", "rest"])
    def test_compute_automatic_metrics_with_custom_metric_spec(self, api_transport):
        """A custom Rouge spec is scored and produces the expected requests."""
        aiplatform.init(
            project=_TEST_PROJECT,
            location=_TEST_LOCATION,
            api_transport=api_transport,
        )
        mock_model = mock.create_autospec(
            generative_models.GenerativeModel, instance=True
        )
        mock_model.generate_content.return_value = _MOCK_MODEL_INFERENCE_RESPONSE
        mock_model._model_name = "publishers/google/model/gemini-1.0-pro"
        test_metrics = [
            _rouge.Rouge(
                rouge_type="rougeLsum",
                use_stemmer=True,
                split_summaries=True,
            )
        ]
        test_eval_task = evaluation.EvalTask(
            dataset=_TEST_EVAL_DATASET_WITHOUT_RESPONSE, metrics=test_metrics
        )
        with mock.patch.object(
            target=gapic_evaluation_services.EvaluationServiceClient,
            attribute="evaluate_instances",
            side_effect=_MOCK_ROUGE_RESULT,
        ) as mock_evaluate_instances:
            test_result = test_eval_task.evaluate(
                model=mock_model,
            )

        assert test_result.summary_metrics["row_count"] == 2
        assert test_result.summary_metrics["rouge/mean"] == 0.75
        assert test_result.summary_metrics["rouge/std"] == pytest.approx(0.35, 0.1)
        assert set(test_result.metrics_table.columns.values) == set(
            [
                "prompt",
                "reference",
                "response",
                "context",
                "instruction",
                "rouge/score",
            ]
        )
        assert list(test_result.metrics_table["rouge/score"].values) == [1, 0.5]

        api_requests = [
            call.kwargs["request"] for call in mock_evaluate_instances.call_args_list
        ]
        assert api_requests == list(_EXPECTED_ROUGE_REQUESTS)

    @pytest.mark.parametrize("api_transport", ["grpc", "rest"])
    def test_compute_pairwise_metrics(self, api_transport):
        """Pairwise metric with both candidate and baseline responses inferred."""
        aiplatform.init(
            project=_TEST_PROJECT,
            location=_TEST_LOCATION,
            api_transport=api_transport,
        )
        mock_baseline_model = mock.create_autospec(
            generative_models.GenerativeModel, instance=True
        )
        mock_baseline_model.generate_content.return_value = (
            _MOCK_MODEL_INFERENCE_RESPONSE
        )
        mock_baseline_model._model_name = "publishers/google/model/gemini-pro"
        mock_candidate_model = mock.create_autospec(
            generative_models.GenerativeModel, instance=True
        )
        mock_candidate_model.generate_content.return_value = (
            _MOCK_MODEL_INFERENCE_RESPONSE
        )
        mock_candidate_model._model_name = "publishers/google/model/gemini-pro"
        # _TEST_PAIRWISE_METRIC is shared module-level state: reset its
        # baseline model in `finally` so a failing evaluate() cannot leak the
        # mock baseline into other tests.
        _TEST_PAIRWISE_METRIC._baseline_model = mock_baseline_model
        try:
            test_metrics = [_TEST_PAIRWISE_METRIC]
            test_eval_task = EvalTask(
                dataset=_TEST_EVAL_DATASET_WITHOUT_RESPONSE, metrics=test_metrics
            )
            mock_metric_results = _MOCK_PAIRWISE_SUMMARIZATION_QUALITY_RESULT
            with mock.patch.object(
                target=gapic_evaluation_services.EvaluationServiceClient,
                attribute="evaluate_instances",
                side_effect=mock_metric_results,
            ):
                test_result = test_eval_task.evaluate(
                    model=mock_candidate_model,
                    prompt_template="{instruction} test prompt template {context}",
                )
        finally:
            _TEST_PAIRWISE_METRIC._baseline_model = None
        assert test_result.summary_metrics["row_count"] == 2
        assert set(test_result.metrics_table.columns.values) == set(
            [
                "context",
                "instruction",
                "prompt",
                "response",
                "reference",
                "baseline_model_response",
                "test_pairwise_metric/pairwise_choice",
                "test_pairwise_metric/explanation",
            ]
        )
        assert list(
            test_result.metrics_table["test_pairwise_metric/pairwise_choice"].values
        ) == ["BASELINE", "CANDIDATE"]
        assert list(
            test_result.metrics_table["test_pairwise_metric/explanation"].values
        ) == [
            "explanation",
            "explanation",
        ]
        assert set(test_result.summary_metrics.keys()) == set(
            [
                "row_count",
                "test_pairwise_metric/candidate_model_win_rate",
                "test_pairwise_metric/baseline_model_win_rate",
            ]
        )
        assert (
            test_result.summary_metrics["test_pairwise_metric/candidate_model_win_rate"]
            == 0.5
        )
        assert (
            test_result.summary_metrics["test_pairwise_metric/baseline_model_win_rate"]
            == 0.5
        )

    @pytest.mark.parametrize("api_transport", ["grpc", "rest"])
    def test_compute_pairwise_metrics_metric_prompt_template_example(
        self, api_transport
    ):
        """Pairwise.SUMMARIZATION_QUALITY with a precomputed baseline column."""
        aiplatform.init(
            project=_TEST_PROJECT,
            location=_TEST_LOCATION,
            api_transport=api_transport,
        )
        eval_dataset = _TEST_EVAL_DATASET_WITHOUT_RESPONSE.copy(deep=True)
        eval_dataset.insert(1, "baseline_model_response", ["baseline", "response"])
        mock_candidate_model = mock.create_autospec(
            generative_models.GenerativeModel, instance=True
        )
        mock_candidate_model.generate_content.return_value = (
            _MOCK_MODEL_INFERENCE_RESPONSE
        )
        mock_candidate_model._model_name = "publishers/google/model/gemini-pro"
        test_metrics = [Pairwise.SUMMARIZATION_QUALITY]
        test_eval_task = EvalTask(dataset=eval_dataset, metrics=test_metrics)
        mock_metric_results = _MOCK_PAIRWISE_SUMMARIZATION_QUALITY_RESULT
        with mock.patch.object(
            target=gapic_evaluation_services.EvaluationServiceClient,
            attribute="evaluate_instances",
            side_effect=mock_metric_results,
        ):
            test_result = test_eval_task.evaluate(
                model=mock_candidate_model,
                prompt_template="{instruction} test prompt template {context}",
            )

        assert test_result.summary_metrics["row_count"] == 2
        assert set(test_result.metrics_table.columns.values) == set(
            [
                "context",
                "instruction",
                "prompt",
                "response",
                "reference",
                "baseline_model_response",
                "pairwise_summarization_quality/pairwise_choice",
                "pairwise_summarization_quality/explanation",
            ]
        )
        assert list(
            test_result.metrics_table[
                "pairwise_summarization_quality/pairwise_choice"
            ].values
        ) == ["BASELINE", "CANDIDATE"]
        assert list(
            test_result.metrics_table[
                "pairwise_summarization_quality/explanation"
            ].values
        ) == [
            "explanation",
            "explanation",
        ]
        assert set(test_result.summary_metrics.keys()) == set(
            [
                "row_count",
                "pairwise_summarization_quality/candidate_model_win_rate",
                "pairwise_summarization_quality/baseline_model_win_rate",
            ]
        )
        assert (
            test_result.summary_metrics[
                "pairwise_summarization_quality/candidate_model_win_rate"
            ]
            == 0.5
        )
        assert (
            test_result.summary_metrics[
                "pairwise_summarization_quality/baseline_model_win_rate"
            ]
            == 0.5
        )

    @pytest.mark.parametrize("api_transport", ["grpc", "rest"])
    def test_compute_pairwise_metrics_without_model_inference(self, api_transport):
        """Pairwise metric where both responses are already in the dataset."""
        aiplatform.init(
            project=_TEST_PROJECT,
            location=_TEST_LOCATION,
            api_transport=api_transport,
        )
        eval_dataset = _TEST_EVAL_DATASET_ALL_INCLUDED.copy(deep=True)
        eval_dataset.insert(1, "baseline_model_response", ["baseline", "response"])
        test_metrics = [Pairwise.SUMMARIZATION_QUALITY]
        test_eval_task = EvalTask(dataset=eval_dataset, metrics=test_metrics)
        mock_metric_results = _MOCK_PAIRWISE_SUMMARIZATION_QUALITY_RESULT
        with mock.patch.object(
            target=gapic_evaluation_services.EvaluationServiceClient,
            attribute="evaluate_instances",
            side_effect=mock_metric_results,
        ):
            test_result = test_eval_task.evaluate()

        assert test_result.summary_metrics["row_count"] == 2
        assert set(test_result.metrics_table.columns.values) == set(
            [
                "prompt",
                "response",
                "baseline_model_response",
                "reference",
                "context",
                "instruction",
                "pairwise_summarization_quality/pairwise_choice",
                "pairwise_summarization_quality/explanation",
                "source",
            ]
        )
        assert list(
            test_result.metrics_table[
                "pairwise_summarization_quality/pairwise_choice"
            ].values
        ) == ["BASELINE", "CANDIDATE"]
        assert list(
            test_result.metrics_table[
                "pairwise_summarization_quality/explanation"
            ].values
        ) == [
            "explanation",
            "explanation",
        ]
        assert set(test_result.summary_metrics.keys()) == set(
            [
                "row_count",
                "pairwise_summarization_quality/candidate_model_win_rate",
                "pairwise_summarization_quality/baseline_model_win_rate",
            ]
        )
        assert (
            test_result.summary_metrics[
                "pairwise_summarization_quality/candidate_model_win_rate"
            ]
            == 0.5
        )
        assert (
            test_result.summary_metrics[
                "pairwise_summarization_quality/baseline_model_win_rate"
            ]
            == 0.5
        )

    @pytest.mark.skipif(
        sys.version_info >= (3, 13), reason="flaky race condition in python 3.13"
    )
    @pytest.mark.parametrize("api_transport", ["grpc", "rest"])
    def test_compute_multiple_metrics(self, api_transport):
        """exact_match, a pointwise and a pairwise metric in one task."""
        aiplatform.init(
            project=_TEST_PROJECT,
            location=_TEST_LOCATION,
            api_transport=api_transport,
        )
        mock_baseline_model = mock.create_autospec(
            generative_models.GenerativeModel, instance=True
        )
        mock_baseline_model.generate_content.return_value = (
            _MOCK_MODEL_INFERENCE_RESPONSE
        )
        mock_baseline_model._model_name = "publishers/google/model/gemini-pro"
        # _TEST_PAIRWISE_METRIC is shared module-level state: reset its
        # baseline model in `finally` so a failing evaluate() cannot leak the
        # mock baseline into other tests.
        _TEST_PAIRWISE_METRIC._baseline_model = mock_baseline_model
        try:
            mock_model = mock.create_autospec(
                generative_models.GenerativeModel, instance=True
            )
            mock_model.generate_content.return_value = _MOCK_MODEL_INFERENCE_RESPONSE
            mock_model._model_name = "publishers/google/model/gemini-pro"
            test_metrics = [
                "exact_match",
                Pointwise.SUMMARIZATION_QUALITY,
                _TEST_PAIRWISE_METRIC,
            ]
            test_eval_task = EvalTask(
                dataset=_TEST_EVAL_DATASET_WITHOUT_RESPONSE, metrics=test_metrics
            )
            # Canned responses interleaved per row: one per metric for row 1,
            # then one per metric for row 2.
            mock_metric_results = (
                _MOCK_EXACT_MATCH_RESULT[0],
                _MOCK_SUMMARIZATION_QUALITY_RESULT[0],
                _MOCK_PAIRWISE_SUMMARIZATION_QUALITY_RESULT[0],
                _MOCK_EXACT_MATCH_RESULT[1],
                _MOCK_SUMMARIZATION_QUALITY_RESULT[1],
                _MOCK_PAIRWISE_SUMMARIZATION_QUALITY_RESULT[1],
            )
            with mock.patch.object(
                target=gapic_evaluation_services.EvaluationServiceClient,
                attribute="evaluate_instances",
                side_effect=mock_metric_results,
            ):
                test_result = test_eval_task.evaluate(
                    model=mock_model,
                    prompt_template="{instruction} test prompt template {context}",
                )
        finally:
            _TEST_PAIRWISE_METRIC._baseline_model = None

        assert test_result.summary_metrics["row_count"] == 2
        assert set(test_result.metrics_table.columns.values) == set(
            [
                "prompt",
                "response",
                "baseline_model_response",
                "reference",
                "context",
                "instruction",
                "exact_match/score",
                "summarization_quality/score",
                "summarization_quality/explanation",
                "test_pairwise_metric/pairwise_choice",
                "test_pairwise_metric/explanation",
            ]
        )
        assert list(test_result.metrics_table["exact_match/score"].values) == [
            1.0,
            0.0,
        ]

        assert list(
            test_result.metrics_table["test_pairwise_metric/pairwise_choice"].values
        ) == ["BASELINE", "CANDIDATE"]
        assert list(
            test_result.metrics_table["test_pairwise_metric/explanation"].values
        ) == [
            "explanation",
            "explanation",
        ]
        assert (
            test_result.summary_metrics["test_pairwise_metric/candidate_model_win_rate"]
            == 0.5
        )
        assert (
            test_result.summary_metrics["test_pairwise_metric/baseline_model_win_rate"]
            == 0.5
        )

        assert list(
            test_result.metrics_table["summarization_quality/score"].values
        ) == [5, 4]
        assert list(
            test_result.metrics_table["summarization_quality/explanation"].values
        ) == [
            "explanation",
            "explanation",
        ]

    def test_eval_result_experiment_run_logging(self):
        """Summary metrics of the evaluation are logged via log_metrics."""
        test_eval_task = EvalTask(
            dataset=_TEST_EVAL_DATASET_ALL_INCLUDED,
            metrics=[Pointwise.FLUENCY],
            experiment=_TEST_EXPERIMENT,
        )

        with mock.patch.multiple(
            metadata._experiment_tracker,
            _experiment=mock.MagicMock(name=_TEST_EXPERIMENT),
            _experiment_run=None,
            set_experiment=mock.DEFAULT,
            reset=mock.DEFAULT,
        ):
            with mock.patch.multiple(
                vertexai.preview,
                start_run=mock.MagicMock(),
                log_params=mock.DEFAULT,
                log_metrics=mock.DEFAULT,
            ) as mock_metadata:
                with mock.patch.object(
                    target=_evaluation,
                    attribute="evaluate",
                    side_effect=[MOCK_EVAL_RESULT],
                ):
                    _ = test_eval_task.evaluate()

        mock_metadata["log_metrics"].assert_called_once_with(
            {"row_count": 1, "mock_metric/mean": 1.0, "mock_metric/std": "NaN"}
        )

    @pytest.mark.skipif(
        sys.version_info >= (3, 13), reason="flaky race condition in python 3.13"
    )
    @pytest.mark.parametrize("api_transport", ["grpc", "rest"])
    def test_rubric_based_instruction_following_metric(self, api_transport):
        """Rubric-based instruction following yields scores and rubric verdicts."""
        aiplatform.init(
            project=_TEST_PROJECT,
            location=_TEST_LOCATION,
            api_transport=api_transport,
        )
        test_metrics = ["rubric_based_instruction_following"]
        test_eval_task = EvalTaskPreview(
            dataset=_TEST_EVAL_DATASET_PROMPT_RESPONSE,
            metrics=test_metrics,
        )
        mock_metric_results = _MOCK_RUBRIC_BASED_INSTRUCTION_FOLLOWING_RESULT
        with mock.patch.object(
            target=gapic_evaluation_services_preview.EvaluationServiceClient,
            attribute="evaluate_instances",
            side_effect=mock_metric_results,
        ):
            test_result = test_eval_task.evaluate()

        assert len(test_result.metrics_table) == 3
        assert (
            test_result.metrics_table.iloc[0][
                "rubric_based_instruction_following/score"
            ]
            == 0.5
        )
        assert test_result.metrics_table.iloc[0][
            "rubric_based_instruction_following/per_rubric_result"
        ] == [
            {"rubric": "rubric_1", "verdict": True},
            {"rubric": "rubric_2", "verdict": False},
        ]
        assert (
            test_result.metrics_table.iloc[1][
                "rubric_based_instruction_following/score"
            ]
            == 0.0
        )
        assert test_result.metrics_table.iloc[1][
            "rubric_based_instruction_following/per_rubric_result"
        ] == [
            {"rubric": "rubric_1", "verdict": False},
            {"rubric": "rubric_2", "verdict": False},
        ]
        assert (
            test_result.metrics_table.iloc[2][
                "rubric_based_instruction_following/score"
            ]
            == 1.0
        )
        assert test_result.metrics_table.iloc[2][
            "rubric_based_instruction_following/per_rubric_result"
        ] == [{"rubric": "rubric_1", "verdict": True}]
        assert (
            test_result.summary_metrics["rubric_based_instruction_following/mean"]
            == 0.5
        )

    @pytest.mark.parametrize("api_transport", ["grpc", "rest"])
    def test_compute_pointwise_metrics_with_raw_output(self, api_transport):
        """Mocked raw_output values appear in the pointwise metrics table."""
        aiplatform.init(
            project=_TEST_PROJECT,
            location=_TEST_LOCATION,
            api_transport=api_transport,
        )
        test_metrics = [_TEST_POINTWISE_METRIC_WITH_RAW_OUTPUT]
        test_eval_task = EvalTaskPreview(
            dataset=_TEST_EVAL_DATASET_ALL_INCLUDED, metrics=test_metrics
        )
        mock_metric_results = _MOCK_POINTWISE_RESULT_WITH_RAW_OUTPUT
        with mock.patch.object(
            target=gapic_evaluation_services_preview.EvaluationServiceClient,
            attribute="evaluate_instances",
            side_effect=mock_metric_results,
        ):
            test_result = test_eval_task.evaluate()
        assert test_result.summary_metrics["row_count"] == 2
        assert set(test_result.metrics_table.columns.values) == set(
            [
                "prompt",
                "response",
                "context",
                "source",
                "instruction",
                "reference",
                "test_pointwise_metric/raw_output",
            ]
        )
        assert test_result.metrics_table["response"].equals(
            _TEST_EVAL_DATASET_ALL_INCLUDED["response"]
        )
        assert test_result.metrics_table["prompt"].equals(
            _TEST_EVAL_DATASET_ALL_INCLUDED["prompt"]
        )
        assert list(
            test_result.metrics_table["test_pointwise_metric/raw_output"].values
        ) == [
            ["raw_output_sample_1.1", "raw_output_sample_1.2"],
            ["raw_output_sample_2.1", "raw_output_sample_2.2"],
        ]

    @pytest.mark.parametrize("api_transport", ["grpc", "rest"])
    def test_compute_pairwise_metrics_with_raw_output(self, api_transport):
        """Mocked raw_output values appear in the pairwise metrics table."""
        aiplatform.init(
            project=_TEST_PROJECT,
            location=_TEST_LOCATION,
            api_transport=api_transport,
        )
        test_metrics = [_TEST_PAIRWISE_METRIC_WITH_RAW_OUTPUT]
        test_eval_task = EvalTaskPreview(
            dataset=_TEST_EVAL_DATASET_ALL_INCLUDED_DEFAULT_FIELDS, metrics=test_metrics
        )
        mock_metric_results = _MOCK_PAIRWISE_RESULT_WITH_RAW_OUTPUT
        with mock.patch.object(
            target=gapic_evaluation_services_preview.EvaluationServiceClient,
            attribute="evaluate_instances",
            side_effect=mock_metric_results,
        ):
            test_result = test_eval_task.evaluate()
        assert test_result.summary_metrics["row_count"] == 2
        assert set(test_result.metrics_table.columns.values) == set(
            [
                "prompt",
                "response",
                "baseline_model_response",
                "context",
                "instruction",
                "test_pairwise_metric/raw_output",
            ]
        )
        assert list(
            test_result.metrics_table["test_pairwise_metric/raw_output"].values
        ) == [
            ["raw_output_sample_1.1", "raw_output_sample_1.2"],
            ["raw_output_sample_2.1", "raw_output_sample_2.2"],
        ]

    @pytest.mark.parametrize("api_transport", ["grpc", "rest"])
    def test_compute_rubric_based_metric(self, api_transport):
        """Rubrics are generated with a mocked model, then critiqued."""
        aiplatform.init(
            project=_TEST_PROJECT,
            location=_TEST_LOCATION,
            api_transport=api_transport,
        )
        # Mock model for rubric generation. It was previously created but never
        # wired in; it is now injected by patching GenerativeModel below.
        mock_model = mock.create_autospec(
            generative_models.GenerativeModel, instance=True
        )
        mock_model.generate_content.return_value = (
            _MOCK_MODEL_RUBRIC_GENERATION_RESPONSE
        )
        mock_model._model_name = "publishers/google/model/gemini-pro"
        test_rubric_based_metric = rubric_based_metric.RubricBasedMetric(
            generation_config=RubricGenerationConfig(
                prompt_template="abc",
            ),
            critique_metric=metric_prompt_template_examples_preview.MetricPromptTemplateExamples.Pointwise.SUMMARIZATION_QUALITY,
        )
        test_metrics = [test_rubric_based_metric]
        test_eval_task = EvalTaskPreview(
            dataset=_TEST_EVAL_DATASET_ALL_INCLUDED_DEFAULT_FIELDS, metrics=test_metrics
        )
        mock_metric_results = _MOCK_POINTWISE_RESULT_WITH_RAW_OUTPUT
        # NOTE(review): assumes rubric generation constructs its default model
        # through `vertexai.generative_models.GenerativeModel`; confirm against
        # rubric_based_metric if this patch target changes.
        with mock.patch.object(
            target=gapic_evaluation_services_preview.EvaluationServiceClient,
            attribute="evaluate_instances",
            side_effect=mock_metric_results,
        ), mock.patch.object(
            target=generative_models,
            attribute="GenerativeModel",
            return_value=mock_model,
        ):
            test_result = test_eval_task.evaluate()
        assert test_result.summary_metrics["row_count"] == 2
        assert set(test_result.metrics_table.columns.values) == set(
            [
                "prompt",
                "response",
                "baseline_model_response",
                "context",
                "instruction",
                "summarization_quality/raw_output",
                "rubrics",
            ]
        )
        assert list(
            test_result.metrics_table["summarization_quality/raw_output"].values
        ) == [
            ["raw_output_sample_1.1", "raw_output_sample_1.2"],
            ["raw_output_sample_2.1", "raw_output_sample_2.2"],
        ]


@pytest.mark.usefixtures("google_auth_mock")
class TestAgentEvaluation:
    """Agent (runnable) evaluation and autorater-config tests for `EvalTaskPreview`.

    Service-backed tests patch the preview
    `EvaluationServiceClient.evaluate_instances` with canned results, so no
    real evaluation requests are sent.
    """

    # Expected error when flipping is enabled but the baseline response column
    # cannot be located in the evaluation dataset.
    _MISSING_BASELINE_RESPONSE_ERROR = (
        "Cannot find the `baseline_model_response` column in the"
        " evaluation dataset to fill the metric prompt template."
    )

    def setup_method(self):
        vertexai.init(
            project=_TEST_PROJECT,
            location=_TEST_LOCATION,
        )

    def teardown_method(self):
        # Wait for anything still queued on the SDK's shared pool before the
        # next test starts.
        initializer.global_pool.shutdown(wait=True)

    @staticmethod
    def _init_sdk(api_transport=None):
        """Calls `aiplatform.init` for the test project and location.

        Args:
            api_transport: Forwarded as `api_transport` when given; omitted
              entirely otherwise, so the SDK keeps its own default transport.
        """
        init_kwargs = {"project": _TEST_PROJECT, "location": _TEST_LOCATION}
        if api_transport is not None:
            init_kwargs["api_transport"] = api_transport
        aiplatform.init(**init_kwargs)

    @staticmethod
    def _evaluate_with_mocked_service(
        eval_task, mock_metric_results, **evaluate_kwargs
    ):
        """Runs `eval_task.evaluate(**evaluate_kwargs)` against a patched service.

        Args:
            eval_task: The `EvalTaskPreview` under test.
            mock_metric_results: Installed as the `side_effect` of the patched
              preview `evaluate_instances`.
            **evaluate_kwargs: Forwarded unchanged to `eval_task.evaluate`.

        Returns:
            A `(result, calls)` tuple: whatever `evaluate()` returned, and the
            patched method's `call_args_list` (one entry per service call, in
            call order).
        """
        with mock.patch.object(
            target=gapic_evaluation_services_preview.EvaluationServiceClient,
            attribute="evaluate_instances",
            side_effect=mock_metric_results,
        ) as mock_evaluate_instances:
            result = eval_task.evaluate(**evaluate_kwargs)
        return result, mock_evaluate_instances.call_args_list

    @staticmethod
    def _assert_autorater_config(calls, expected_config):
        """Asserts every captured request carries `expected_config`.

        Args:
            calls: `call_args_list` of the patched `evaluate_instances`; each
              call is expected to pass its request as the `request` kwarg.
            expected_config: The autorater config every request must carry.

        Returns:
            The captured request objects, for any further checks.
        """
        api_requests = [call.kwargs["request"] for call in calls]
        # Fail with a clear message instead of an IndexError when the service
        # was never reached.
        assert api_requests, "evaluate_instances was never called"
        for api_request in api_requests:
            assert api_request.autorater_config == expected_config
        return api_requests

    @pytest.mark.parametrize("api_transport", ["grpc", "rest"])
    def test_runnable_response_eval_with_runnable_inference(self, api_transport):
        """Responses produced by a runnable are scored with a pointwise metric."""
        self._init_sdk(api_transport)
        mock_runnable = mock.create_autospec(reasoning_engines.Queryable, instance=True)
        mock_runnable.query.return_value = _MOCK_RUNNABLE_INFERENCE_RESPONSE
        test_eval_task = EvalTaskPreview(
            dataset=_TEST_AGENT_EVAL_DATASET_WITHOUT_RESPONSE,
            metrics=[PointwisePreview.COHERENCE],
        )

        test_result, _ = self._evaluate_with_mocked_service(
            test_eval_task,
            _MOCK_COHERENCE_RESULT,
            runnable=mock_runnable,
            prompt_template="test prompt template",
        )

        assert test_result.summary_metrics["row_count"] == 2
        assert test_result.summary_metrics["coherence/mean"] == 4.5
        assert test_result.summary_metrics["coherence/std"] == pytest.approx(0.7, 0.1)
        assert set(test_result.metrics_table.columns.values) == {
            "prompt",
            "reference_trajectory",
            "response",
            "latency_in_seconds",
            "failure",
            "predicted_trajectory",
            "coherence/score",
            "coherence/explanation",
        }
        assert list(test_result.metrics_table["coherence/score"].values) == [5, 4]
        assert list(test_result.metrics_table["coherence/explanation"].values) == [
            "explanation",
            "explanation",
        ]

    @pytest.mark.parametrize("api_transport", ["grpc", "rest"])
    def test_runnable_trajectory_eval_with_runnable_inference(self, api_transport):
        """Trajectories produced by a runnable are scored with exact match."""
        self._init_sdk(api_transport)
        mock_runnable = mock.create_autospec(reasoning_engines.Queryable, instance=True)
        mock_runnable.query.return_value = _MOCK_RUNNABLE_INFERENCE_RESPONSE
        test_eval_task = EvalTaskPreview(
            dataset=_TEST_AGENT_EVAL_DATASET_WITHOUT_RESPONSE,
            metrics=["trajectory_exact_match"],
        )

        test_result, _ = self._evaluate_with_mocked_service(
            test_eval_task,
            _MOCK_TRAJECTORY_EXACT_MATCH_RESULT,
            runnable=mock_runnable,
        )

        assert test_result.summary_metrics["row_count"] == 2
        assert test_result.summary_metrics["trajectory_exact_match/mean"] == 0.5
        assert test_result.summary_metrics[
            "trajectory_exact_match/std"
        ] == pytest.approx(0.7, 0.1)
        assert set(test_result.metrics_table.columns.values) == {
            "prompt",
            "response",
            "latency_in_seconds",
            "failure",
            "predicted_trajectory",
            "reference_trajectory",
            "trajectory_exact_match/score",
        }
        assert list(
            test_result.metrics_table["trajectory_exact_match/score"].values
        ) == [1.0, 0.0]

    @pytest.mark.parametrize("api_transport", ["grpc", "rest"])
    def test_pointwise_autorater_request_config_enabled(self, api_transport):
        """A task-level autorater config is sent with pointwise requests."""
        self._init_sdk(api_transport)
        test_eval_task = EvalTaskPreview(
            dataset=_TEST_EVAL_DATASET_ALL_INCLUDED,
            metrics=[PointwisePreview.SUMMARIZATION_QUALITY],
            autorater_config=_AUTORATER_CONFIG,
        )

        _, calls = self._evaluate_with_mocked_service(
            test_eval_task, _MOCK_SUMMARIZATION_QUALITY_RESULT_PREVIEW
        )

        self._assert_autorater_config(calls, _AUTORATER_CONFIG)

    @pytest.mark.parametrize("api_transport", ["grpc", "rest"])
    def test_pointwise_autorater_from_metric(self, api_transport):
        """A metric-level autorater config is sent with pointwise requests."""
        self._init_sdk(api_transport)
        test_metrics = [
            pointwise_metric_preview.PointwiseMetric(
                metric=constants_preview.Metric.SUMMARIZATION_QUALITY,
                metric_prompt_template=default_templates_preview.SUMMARIZATION_QUALITY_PROMPT_TEMPLATE,
                autorater_config=_AUTORATER_CONFIG,
            )
        ]
        test_eval_task = EvalTaskPreview(
            dataset=_TEST_EVAL_DATASET_ALL_INCLUDED_DEFAULT_FIELDS,
            metrics=test_metrics,
        )

        _, calls = self._evaluate_with_mocked_service(
            test_eval_task, _MOCK_SUMMARIZATION_QUALITY_RESULT
        )

        self._assert_autorater_config(calls, _AUTORATER_CONFIG)

    @pytest.mark.parametrize("api_transport", ["grpc", "rest"])
    def test_pointwise_autorater_from_metric_override(self, api_transport):
        """A metric-level autorater config takes precedence over the task's."""
        self._init_sdk(api_transport)
        test_metrics = [
            pointwise_metric_preview.PointwiseMetric(
                metric=constants_preview.Metric.SUMMARIZATION_QUALITY,
                metric_prompt_template=default_templates_preview.SUMMARIZATION_QUALITY_PROMPT_TEMPLATE,
                autorater_config=_AUTORATER_CONFIG,
            )
        ]
        task_autorater_config = AutoraterConfig(
            autorater_model="test_another_model",
            sampling_count=2,
            flip_enabled=False,
        )
        test_eval_task = EvalTaskPreview(
            dataset=_TEST_EVAL_DATASET_ALL_INCLUDED_DEFAULT_FIELDS,
            metrics=test_metrics,
            autorater_config=task_autorater_config,
        )

        _, calls = self._evaluate_with_mocked_service(
            test_eval_task, _MOCK_SUMMARIZATION_QUALITY_RESULT
        )

        api_requests = self._assert_autorater_config(calls, _AUTORATER_CONFIG)
        assert all(
            api_request.autorater_config != task_autorater_config
            for api_request in api_requests
        )

    @pytest.mark.parametrize("api_transport", ["grpc", "rest"])
    def test_pairwise_autorater_request_config_enabled(self, api_transport):
        """A task-level autorater config is sent with pairwise requests."""
        self._init_sdk(api_transport)
        test_eval_task = EvalTaskPreview(
            dataset=_TEST_EVAL_DATASET_ALL_INCLUDED,
            metrics=[PairwisePreview.SUMMARIZATION_QUALITY],
            autorater_config=_AUTORATER_CONFIG,
            metric_column_mapping={
                "baseline_model_response": "reference",
                "response": "response",
            },
        )

        _, calls = self._evaluate_with_mocked_service(
            test_eval_task, _MOCK_SUMMARIZATION_QUALITY_RESULT_PREVIEW
        )

        self._assert_autorater_config(calls, _AUTORATER_CONFIG)

    @pytest.mark.parametrize("api_transport", ["grpc", "rest"])
    def test_pairwise_autorater_from_metric(self, api_transport):
        """A metric-level autorater config is sent with pairwise requests."""
        self._init_sdk(api_transport)
        test_metrics = [
            pairwise_metric_preview.PairwiseMetric(
                metric=constants_preview.Metric.SUMMARIZATION_QUALITY,
                metric_prompt_template=default_templates_preview.PAIRWISE_SUMMARIZATION_QUALITY_PROMPT_TEMPLATE,
                autorater_config=_AUTORATER_CONFIG,
            )
        ]
        test_eval_task = EvalTaskPreview(
            dataset=_TEST_EVAL_DATASET_ALL_INCLUDED_DEFAULT_FIELDS,
            metrics=test_metrics,
        )

        _, calls = self._evaluate_with_mocked_service(
            test_eval_task, _MOCK_SUMMARIZATION_QUALITY_RESULT
        )

        self._assert_autorater_config(calls, _AUTORATER_CONFIG)

    @pytest.mark.parametrize("api_transport", ["grpc", "rest"])
    def test_pairwise_flipping_missing_baseline_model_response_column(
        self, api_transport
    ):
        """Flipping without a mapped baseline response column is rejected."""
        self._init_sdk(api_transport)
        test_eval_task = EvalTaskPreview(
            dataset=_TEST_EVAL_DATASET_ALL_INCLUDED,
            metrics=[PairwisePreview.SUMMARIZATION_QUALITY],
            autorater_config=AutoraterConfig(flip_enabled=True),
            metric_column_mapping={
                "response": "response",
            },
        )
        with pytest.raises(
            ValueError, match=re.escape(self._MISSING_BASELINE_RESPONSE_ERROR)
        ):
            test_eval_task.evaluate()

    @pytest.mark.parametrize("api_transport", ["grpc", "rest"])
    def test_pairwise_flipping_incorrect_baseline_model_response_mapping(
        self, api_transport
    ):
        """Flipping with a baseline mapping to a nonexistent column is rejected."""
        self._init_sdk(api_transport)
        test_eval_task = EvalTaskPreview(
            dataset=_TEST_EVAL_DATASET_ALL_INCLUDED,
            metrics=[PairwisePreview.SUMMARIZATION_QUALITY],
            autorater_config=AutoraterConfig(flip_enabled=True),
            metric_column_mapping={
                "response": "response",
                "baseline_model_response": "incorrect_column_name",
            },
        )
        with pytest.raises(
            ValueError, match=re.escape(self._MISSING_BASELINE_RESPONSE_ERROR)
        ):
            test_eval_task.evaluate()

    @pytest.mark.parametrize("sampling_count", [-1, 33])
    def test_pairwise_invalid_sampling_count(self, sampling_count):
        """Sampling counts outside [1, 32] are rejected."""
        self._init_sdk()
        test_eval_task = EvalTaskPreview(
            dataset=_TEST_EVAL_DATASET_ALL_INCLUDED,
            metrics=[PairwisePreview.SUMMARIZATION_QUALITY],
            autorater_config=AutoraterConfig(sampling_count=sampling_count),
        )
        with pytest.raises(
            ValueError,
            match=re.escape(
                "autorater_config.sampling_count must be in the range [1, 32]."
            ),
        ):
            test_eval_task.evaluate()

    @pytest.mark.parametrize("api_transport", ["grpc", "rest"])
    def test_pairwise_autorater_request_default_metric_column_mapping(
        self, api_transport
    ):
        """Pairwise autorater config is sent when no column mapping is given."""
        self._init_sdk(api_transport)
        test_eval_task = EvalTaskPreview(
            dataset=_TEST_EVAL_DATASET_ALL_INCLUDED_DEFAULT_FIELDS,
            metrics=[PairwisePreview.SUMMARIZATION_QUALITY],
            autorater_config=_AUTORATER_CONFIG,
        )

        _, calls = self._evaluate_with_mocked_service(
            test_eval_task, _MOCK_SUMMARIZATION_QUALITY_RESULT
        )

        self._assert_autorater_config(calls, _AUTORATER_CONFIG)


@pytest.mark.usefixtures("google_auth_mock")
class TestEvaluationErrors:
    """Tests that `EvalTask.evaluate()` rejects invalid metrics, data and args."""

    def setup_method(self):
        vertexai.init(
            project=_TEST_PROJECT,
            location=_TEST_LOCATION,
        )

    def teardown_method(self):
        # Wait for anything still queued on the SDK's shared pool before the
        # next test starts.
        initializer.global_pool.shutdown(wait=True)

    @staticmethod
    def _byor_conflict_message(column_name):
        """Builds the expected "model passed but response column present" error.

        Args:
            column_name: The response column the dataset already contains.

        Returns:
            The full error text to match. It must stay byte-identical to the
            message being matched, including its duplicated "or or".
        """
        return (
            "The `model` parameter or `baseline_model` in pairwise metric is"
            " specified, but the evaluation `dataset` contains model response"
            f" column or baseline model response column `{column_name}`"
            " to perform bring-your-own-response(BYOR) evaluation. If you would"
            " like to perform evaluation using the dataset with the"
            " existing model response column or or baseline model response column"
            f" `{column_name}`, please remove `model` parameter in"
            " `EvalTask.evaluate()` function or `baseline_model` in"
            " `PairwiseMetric`."
        )

    def test_evaluate_empty_metrics(self):
        test_eval_task = EvalTask(dataset=_TEST_EVAL_DATASET_WITHOUT_PROMPT, metrics=[])
        with pytest.raises(ValueError, match="Metrics cannot be empty."):
            test_eval_task.evaluate()

    def test_evaluate_invalid_metrics(self):
        metric_name = "invalid_metric"
        with pytest.raises(
            ValueError,
            match=f"Metric name: {metric_name} is not supported.",
        ):
            # Construction is inside the `raises` block because the
            # unsupported name may be rejected there rather than in evaluate().
            test_eval_task = EvalTask(
                dataset=_TEST_EVAL_DATASET_WITHOUT_PROMPT, metrics=[metric_name]
            )
            test_eval_task.evaluate()

    def test_evaluate_duplicate_string_metric(self):
        test_eval_task = EvalTask(
            dataset=_TEST_EVAL_DATASET_WITHOUT_PROMPT,
            metrics=["exact_match", "exact_match"],
        )
        with pytest.raises(
            ValueError,
            match="Duplicate string metric name found: 'exact_match'",
        ):
            test_eval_task.evaluate()

    def test_evaluate_duplicate_metric_instances(self):
        test_eval_task = EvalTask(
            dataset=_TEST_EVAL_DATASET_ALL_INCLUDED,
            metrics=[
                Pointwise.SUMMARIZATION_QUALITY,
                Pointwise.SUMMARIZATION_QUALITY,
            ],
        )
        with pytest.raises(
            ValueError,
            match=(
                "Duplicate Metric instances of the same metric name found:"
                " 'summarization_quality'"
            ),
        ):
            test_eval_task.evaluate()

    def test_evaluate_invalid_experiment_run_name(self):
        test_eval_task = EvalTask(
            dataset=_TEST_EVAL_DATASET_WITHOUT_PROMPT, metrics=_TEST_METRICS
        )
        with pytest.raises(ValueError, match="Experiment is not set"):
            test_eval_task.evaluate(experiment_run_name="invalid_experiment_run_name")

        with pytest.raises(ValueError, match="Experiment is not set"):
            test_eval_task.display_runs()

    def test_evaluate_experiment_name_already_exists(self, mock_experiment_tracker):
        test_eval_task = EvalTask(
            dataset=_TEST_EVAL_DATASET_WITHOUT_PROMPT,
            metrics=_TEST_METRICS,
            experiment="test_eval_experiment_name",
        )
        mock_experiment_tracker.experiment_run.return_value = "experiment_run_1"
        with pytest.raises(ValueError, match="Experiment run already exists"):
            test_eval_task.evaluate(experiment_run_name="experiment_run_2")

    def test_evaluate_response_column_and_model_provided(self):
        test_eval_task = EvalTask(
            dataset=_TEST_EVAL_DATASET_ALL_INCLUDED,
            metrics=[_TEST_POINTWISE_METRIC],
        )
        with pytest.raises(
            ValueError, match=re.escape(self._byor_conflict_message("response"))
        ):
            test_eval_task.evaluate(model=mock.MagicMock())

    def test_evaluate_baseline_response_column_and_baseline_model_provided(self):
        eval_dataset = _TEST_EVAL_DATASET_WITHOUT_RESPONSE.copy(deep=True)
        eval_dataset.insert(1, "baseline_model_response", ["baseline", "response"])
        # `_TEST_PAIRWISE_METRIC` is shared module state: patch (rather than
        # assign) so it is restored even when the assertion below fails.
        with mock.patch.object(
            _TEST_PAIRWISE_METRIC, "_baseline_model", mock.MagicMock()
        ):
            test_eval_task = EvalTask(
                dataset=eval_dataset,
                metrics=[_TEST_PAIRWISE_METRIC],
            )
            with pytest.raises(
                ValueError,
                match=re.escape(
                    self._byor_conflict_message("baseline_model_response")
                ),
            ):
                test_eval_task.evaluate(model=mock.MagicMock())

    def test_evaluate_baseline_model_provided_but_no_baseline_response_column(self):
        mock_baseline_model = mock.create_autospec(
            generative_models.GenerativeModel, instance=True
        )
        mock_baseline_model.generate_content.return_value = (
            _MOCK_MODEL_INFERENCE_RESPONSE
        )
        mock_baseline_model._model_name = "publishers/google/model/gemini-pro"

        mock_candidate_model = mock.create_autospec(
            generative_models.GenerativeModel, instance=True
        )
        mock_candidate_model.generate_content.return_value = (
            _MOCK_MODEL_INFERENCE_RESPONSE
        )
        mock_candidate_model._model_name = "publishers/google/model/gemini-1.0-pro"

        # Patch the shared metric so its baseline model is restored even if
        # evaluate() raises.
        with mock.patch.object(
            _TEST_PAIRWISE_METRIC, "_baseline_model", mock_baseline_model
        ):
            test_eval_task = EvalTask(
                dataset=_TEST_EVAL_DATASET_WITHOUT_RESPONSE.copy(deep=True),
                metrics=[_TEST_PAIRWISE_METRIC],
            )
            with mock.patch.object(
                target=gapic_evaluation_services.EvaluationServiceClient,
                attribute="evaluate_instances",
                side_effect=_MOCK_PAIRWISE_RESULT,
            ):
                test_result = test_eval_task.evaluate(
                    model=mock_candidate_model,
                )
        assert test_result.summary_metrics["row_count"] == 2

    def test_evaluate_response_column_and_model_not_provided(self):
        test_eval_task = EvalTask(
            dataset=_TEST_EVAL_DATASET_SINGLE,
            metrics=[_TEST_POINTWISE_METRIC],
        )
        with pytest.raises(
            ValueError,
            match=re.escape(
                (
                    "Cannot find the `response` column in the evaluation dataset"
                    " to fill the metric prompt template for"
                    " `test_pointwise_metric` metric."
                )
            ),
        ):
            test_eval_task.evaluate()

    def test_evaluate_baseline_model_response_column_not_provided(
        self,
    ):
        test_eval_dataset = _TEST_EVAL_DATASET_SINGLE.copy(deep=True)
        test_eval_dataset.insert(1, "response", ["test", "response"])
        test_eval_task = EvalTask(
            dataset=test_eval_dataset,
            metrics=[_TEST_PAIRWISE_METRIC],
        )
        with pytest.raises(
            ValueError,
            match=re.escape(
                (
                    "Cannot find the `baseline_model_response` column in the"
                    " evaluation dataset to fill the metric prompt template for"
                    " `test_pairwise_metric` metric."
                )
            ),
        ):
            test_eval_task.evaluate()

    @pytest.mark.parametrize("eval_task_version", [EvalTask, EvalTaskPreview])
    def test_evaluate_response_column_not_provided(self, eval_task_version):
        test_eval_task = eval_task_version(
            dataset=_TEST_EVAL_DATASET_SINGLE,
            metrics=["exact_match"],
        )
        with pytest.raises(
            KeyError,
            match=re.escape(
                (
                    "Required column `response` not found in the evaluation "
                    "dataset. The columns in the evaluation dataset are ['prompt']"
                )
            ),
        ):
            test_eval_task.evaluate()

    @pytest.mark.parametrize("eval_task_version", [EvalTask, EvalTaskPreview])
    def test_evaluate_reference_column_not_provided(self, eval_task_version):
        test_eval_task = eval_task_version(
            dataset=pd.DataFrame({"response": ["test", "text"]}),
            metrics=["exact_match"],
        )
        with pytest.raises(
            KeyError,
            match=re.escape(
                (
                    "Required column `reference` not found in the evaluation "
                    "dataset. The columns in the evaluation dataset are ['response']"
                )
            ),
        ):
            test_eval_task.evaluate()

    def test_evaluate_reference_or_source_column_not_provided(
        self,
    ):
        test_eval_task = EvalTask(
            dataset=pd.DataFrame({"response": ["test", "text"]}),
            metrics=[_TEST_COMET, _TEST_METRICX],
        )
        with pytest.raises(
            KeyError,
            match=re.escape(
                (
                    "Required column `source` not found in the evaluation "
                    "dataset. The columns in the evaluation dataset are ['response']"
                )
            ),
        ):
            test_eval_task.evaluate()

    def test_evaluate_invalid_prompt_template_variables(self):
        test_eval_task = EvalTask(
            dataset=_TEST_EVAL_DATASET_SINGLE,
            metrics=[Pointwise.FLUENCY],
        )
        with pytest.raises(
            ValueError,
            match=re.escape(
                (
                    "Failed to assemble prompt template: The following column(s) are"
                    " missing: invalid_variable. Please verify prompt_template"
                    " variables {'invalid_variable'} and evaluation dataset"
                    " column names {'prompt'}."
                )
            ),
        ):
            test_eval_task.evaluate(
                prompt_template="test_prompt_template {invalid_variable}",
            )

    def test_evaluate_pairwise_metrics_with_multiple_baseline_models(self):
        mock_baseline_model_1 = mock.create_autospec(
            generative_models.GenerativeModel, instance=True
        )
        mock_baseline_model_1._model_name = "publishers/google/model/gemini-1.0-pro"
        mock_baseline_model_2 = mock.create_autospec(
            generative_models.GenerativeModel, instance=True
        )
        mock_baseline_model_2._model_name = "publishers/google/model/gemini-1.5-pro"
        test_metrics = [
            pairwise_metric.PairwiseMetric(
                metric="pairwise_metric_1",
                metric_prompt_template="test_prompt_template",
                baseline_model=mock_baseline_model_1,
            ),
            pairwise_metric.PairwiseMetric(
                metric="pairwise_metric_2",
                metric_prompt_template="test_prompt_template",
                baseline_model=mock_baseline_model_2,
            ),
        ]
        test_eval_task = EvalTask(
            dataset=_TEST_EVAL_DATASET_ALL_INCLUDED, metrics=test_metrics
        )
        with pytest.raises(
            ValueError,
            match="Not all `PairwiseMetric` instances have the same `baseline_model`",
        ):
            test_eval_task.evaluate()

    def test_evaluate_invalid_model_and_dataset_input(self):
        test_eval_task = EvalTask(
            dataset=_TEST_EVAL_DATASET_WITHOUT_PROMPT,
            metrics=[_TEST_POINTWISE_METRIC],
        )
        with pytest.raises(
            ValueError, match=re.escape(self._byor_conflict_message("response"))
        ):
            test_eval_task.evaluate(
                model=generative_models.GenerativeModel(model_name="invalid_model_name")
            )

    def test_unmatched_metric_column_mapping(self):
        # NOTE(review): despite its name, this test passes no
        # `metric_column_mapping` and asserts the same BYOR conflict error as
        # the test above; confirm the intended scenario.
        test_eval_task = EvalTask(
            dataset=_TEST_EVAL_DATASET_ALL_INCLUDED,
            metrics=[_TEST_POINTWISE_METRIC_FREE_STRING],
        )
        with pytest.raises(
            ValueError, match=re.escape(self._byor_conflict_message("response"))
        ):
            test_eval_task.evaluate(
                model=generative_models.GenerativeModel(model_name="invalid_model_name")
            )


@pytest.mark.usefixtures("google_auth_mock")
class TestEvaluationUtils:
    """Unit tests for `vertexai.evaluation.utils` and related preview helpers."""

    def setup_method(self):
        vertexai.init(
            project=_TEST_PROJECT,
            location=_TEST_LOCATION,
        )

    def teardown_method(self):
        # Wait for and shut down the SDK's shared pool; presumably so that
        # background work from one test does not bleed into the next.
        initializer.global_pool.shutdown(wait=True)

    def test_create_evaluation_service_client(self):
        """The factory returns the override-aware evaluation client."""
        client = utils.create_evaluation_service_client()
        assert isinstance(client, utils._EvaluationServiceClientWithOverride)

    def test_load_dataset_from_dataframe(self):
        """A DataFrame input is returned with equal contents."""
        data = {"col1": [1, 2], "col2": ["a", "b"]}
        df = pd.DataFrame(data)
        loaded_df = utils.load_dataset(df)
        assert loaded_df.equals(df)

    def test_load_dataset_from_dict(self):
        """A dict of columns is converted into a DataFrame."""
        data = {"col1": [1, 2], "col2": ["a", "b"]}
        loaded_df = utils.load_dataset(data)
        assert isinstance(loaded_df, pd.DataFrame)
        assert loaded_df.to_dict("list") == data

    def test_load_dataset_from_gcs_jsonl(self):
        """A `gs://...jsonl` source is parsed from the (mocked) file contents."""
        source = "gs://test_bucket/test_file.jsonl"
        with mock.patch.object(
            utils,
            "_read_gcs_file_contents",
            return_value=_TEST_JSONL_FILE_CONTENT,
        ):
            loaded_df = utils.load_dataset(source)

        assert isinstance(loaded_df, pd.DataFrame)
        assert loaded_df.to_dict("list") == {
            "prompt": ["prompt", "test"],
            "reference": ["reference", "test"],
        }

    def test_load_dataset_from_gcs_csv(self):
        """A `gs://...csv` source is parsed from the (mocked) file contents."""
        source = "gs://test_bucket/test_file.csv"
        with mock.patch.object(
            utils, "_read_gcs_file_contents", return_value=_TEST_CSV_FILE_CONTENT
        ):
            loaded_df = utils.load_dataset(source)

        assert isinstance(loaded_df, pd.DataFrame)
        assert loaded_df.to_dict("list") == {
            "reference": ["test", "text"],
            "context": ["test", "text"],
            "instruction": ["test", "text"],
        }

    def test_load_dataset_from_bigquery(self):
        """A `bq://` source is delegated to `_load_bigquery` (mocked here)."""
        source = "bq://project-id.dataset.table_name"
        with mock.patch.object(
            utils, "_load_bigquery", return_value=_TEST_EVAL_DATASET_WITHOUT_PROMPT
        ):
            loaded_df = utils.load_dataset(source)

        assert isinstance(loaded_df, pd.DataFrame)
        assert loaded_df.equals(_TEST_EVAL_DATASET_WITHOUT_PROMPT)

    def test_initialization(self):
        """`RateLimiter` stores 1/rate and rejects non-positive rates."""
        limiter = utils.RateLimiter(rate=2)
        assert limiter.seconds_per_event == 0.5

        with pytest.raises(ValueError, match="Rate must be a positive number"):
            utils.RateLimiter(-1)
        with pytest.raises(ValueError, match="Rate must be a positive number"):
            utils.RateLimiter(0)

    def test_admit(self):
        """`_admit` returns 0 when an event may proceed, else the remaining wait."""
        rate_limiter = utils.RateLimiter(rate=2)

        # The first event is admitted immediately.
        assert rate_limiter._admit() == 0

        time.sleep(0.1)
        delay = rate_limiter._admit()
        # Expect one 0.5s interval (rate=2) minus the ~0.1s slept. Use an
        # absolute tolerance: `time.sleep` may overshoot by several ms on a
        # loaded machine, which the previous 1% relative tolerance (+/-4ms)
        # could not absorb and made this test flaky.
        assert delay == pytest.approx(0.4, abs=0.05)

        time.sleep(0.5)
        delay = rate_limiter._admit()
        assert delay == 0

    def test_sleep_and_advance(self):
        """The first call returns promptly; the next one waits ~one interval."""
        rate_limiter = utils.RateLimiter(rate=2)

        start_time = time.time()
        rate_limiter.sleep_and_advance()
        assert (time.time() - start_time) < 0.1

        start_time = time.time()
        rate_limiter.sleep_and_advance()
        assert (time.time() - start_time) >= 0.5

    def test_thread_safety(self):
        """Concurrent callers are spaced out: 10 calls at rate=2 take >= 4.5s."""
        rate_limiter = utils.RateLimiter(rate=2)
        start_time = time.time()

        def target():
            rate_limiter.sleep_and_advance()

        threads = [threading.Thread(target=target) for _ in range(10)]
        for thread in threads:
            thread.start()
        for thread in threads:
            thread.join()

        # Verify that the total minimum time should be 4.5 seconds
        # (9 intervals of 0.5 seconds each).
        total_time = time.time() - start_time
        assert total_time >= 4.5

    # TODO(b/361123127) Add test_to_metrics_spec back

    def test_initialize_metric_column_mapping(self):
        """A user-provided column mapping is expanded to the expected mapping."""
        metric_column_mapping = {
            "prompt": "prompt2",
            "response": "response1",
            "reference": "reference",
        }
        converted_metric_column_mapping = utils.initialize_metric_column_mapping(
            metric_column_mapping=metric_column_mapping,
            dataset=_TEST_EVAL_DATASET_ALL_INCLUDED,
        )
        assert converted_metric_column_mapping == _EXPECTED_COLUMN_MAPPING

    def test_upload_results(self, mock_storage_blob_from_string):
        """Results CSV and a summary JSON are written under the output prefix."""
        with mock.patch("json.dump") as mock_json_dump:
            evaluation.utils.upload_evaluation_results(
                MOCK_EVAL_RESULT,
                _TEST_BUCKET,
                _TEST_FILE_NAME,
                "candidate_model",
                "baseline_model",
                "gs://test-bucket/test-dataset.csv",
                [_TEST_POINTWISE_METRIC, _TEST_PAIRWISE_METRIC],
            )

        mock_storage_blob_from_string.assert_any_call(
            "gs://test-bucket/test-file-name/test-file-name.csv",
            client=mock.ANY,
        )
        mock_storage_blob_from_string.assert_any_call(
            "gs://test-bucket/test-file-name/summary_metrics.json",
            client=mock.ANY,
        )
        mock_json_dump.assert_called_once_with(
            {
                "summary_metrics": MOCK_EVAL_RESULT.summary_metrics,
                "candidate_model_name": "candidate_model",
                "baseline_model_name": "baseline_model",
                "dataset_uri": "gs://test-bucket/test-dataset.csv",
                "metric_descriptions": {
                    "test_pointwise_metric": {
                        "criteria": _CRITERIA,
                        "rating_rubric": _POINTWISE_RATING_RUBRIC,
                    },
                    "test_pairwise_metric": {
                        "criteria": _CRITERIA,
                        "rating_rubric": _PAIRWISE_RATING_RUBRIC,
                    },
                },
            },
            mock.ANY,
        )

    def test_upload_results_with_default_output_file_name(
        self, mock_storage_blob_from_string
    ):
        """Without a file name, output goes to `eval_results_<unique name>/`."""
        mock_metric_results = _MOCK_EXACT_MATCH_RESULT
        with mock.patch.object(
            aiplatform_utils, "timestamped_unique_name"
        ) as mock_timestamped_unique_name:
            with mock.patch.object(
                target=gapic_evaluation_services.EvaluationServiceClient,
                attribute="evaluate_instances",
                side_effect=mock_metric_results,
            ):
                mock_timestamped_unique_name.return_value = "2025-02-10-12-00-00-12345"
                eval_dataset = pd.DataFrame(
                    {
                        "response": ["test", "text"],
                        "reference": ["test", "ref"],
                    }
                )
                test_metrics = ["exact_match"]
                test_eval_task = EvalTask(
                    dataset=eval_dataset,
                    metrics=test_metrics,
                    output_uri_prefix=_TEST_BUCKET,
                )
                _ = test_eval_task.evaluate()
        mock_storage_blob_from_string.assert_any_call(
            "gs://test-bucket/eval_results_2025-02-10-12-00-00-12345/summary_metrics.json",
            client=mock.ANY,
        )

    def test_validate_metrics_multiple_rubric_based_metrics(self):
        """More than one RubricBasedMetric in a single task is rejected."""
        generation_config = RubricGenerationConfig(prompt_template="abc")
        test_metrics = [
            rubric_based_metric.RubricBasedMetric(
                generation_config=generation_config,
                critique_metric=metric_prompt_template_examples_preview.MetricPromptTemplateExamples.Pointwise.COHERENCE,
            ),
            rubric_based_metric.RubricBasedMetric(
                generation_config=generation_config,
                critique_metric=metric_prompt_template_examples_preview.MetricPromptTemplateExamples.Pointwise.FLUENCY,
            ),
        ]
        test_eval_task = EvalTaskPreview(
            dataset=_TEST_EVAL_DATASET_PROMPT_RESPONSE,
            metrics=test_metrics,
        )
        with pytest.raises(
            ValueError,
            match=re.escape("Multiple rubric based metrics are not supported."),
        ):
            _ = test_eval_task.evaluate()

    def test_default_rubrics_parser_succeeds(self):
        """Well-formed rubric output is parsed into a `questions` dict."""
        parsed_rubrics = utils_preview.parse_rubrics(_UNPARSED_RUBRIC)
        assert parsed_rubrics == {"questions": ["test_rubric"]}

    def test_default_rubrics_parser_with_invalid_json(self):
        """Unparseable rubric output yields an empty `questions` string."""
        parsed_rubrics = utils_preview.parse_rubrics(_INVALID_UNPARSED_RUBRIC)
        assert parsed_rubrics == {"questions": ""}

    def test_generate_responses_from_gemini_model(self):
        """One response is collected per dataset row from the mocked model."""
        mock_model = mock.create_autospec(
            generative_models.GenerativeModel, instance=True
        )
        mock_model.generate_content.return_value = _MOCK_MODEL_INFERENCE_RESPONSE
        mock_model._model_name = "publishers/google/model/gemini-1.0-pro"
        response_list = pre_eval_utils_preview._generate_responses_from_gemini_model(
            mock_model, _TEST_EVAL_DATASET_WITHOUT_RESPONSE
        )
        assert response_list == ["test_response", "test_response"]

    def test_generate_responses_from_gemini_model_with_multimodal_dataset(self):
        """A multimodal template is split into text and file parts per call."""
        mock_model = mock.create_autospec(
            generative_models.GenerativeModel, instance=True
        )
        mock_model.generate_content.return_value = _MOCK_MODEL_INFERENCE_RESPONSE
        mock_model._model_name = "publishers/google/model/gemini-1.0-pro"
        mm_prompt_template = "This is a multimodal prompt: {prompt} and {response}"
        response_list = pre_eval_utils_preview._generate_responses_from_gemini_model(
            mock_model,
            _TEST_MULTIMODAL_MODEL_DATASET,
            mm_prompt_template,
        )
        assert response_list == ["test_response", "test_response"]
        assert mock_model.generate_content.call_count == 2
        # Inspect the contents list passed on the last call.
        assert (
            mock_model.generate_content.call_args[0][0][0]
            == "This is a multimodal prompt: "
        )
        assert mock_model.generate_content.call_args[0][0][1] == "text_prompt"
        assert mock_model.generate_content.call_args[0][0][2] == " and "
        assert (
            mock_model.generate_content.call_args[0][0][3].file_data.file_uri
            == "gs://test-bucket/image4.png"
        )


class TestPromptTemplate:
    """Tests for `evaluation.PromptTemplate` and the metric prompt builders."""

    def test_init(self):
        raw = "Hello, {name}!"
        assert evaluation.PromptTemplate(raw).template == raw

    def test_get_variables(self):
        tmpl = evaluation.PromptTemplate("Hello, {name}! Today is {day}.")
        assert tmpl.variables == {"name", "day"}

    def test_format(self):
        tmpl = evaluation.PromptTemplate("Hello, {name}! Today is {day}.")
        filled = tmpl.assemble(name="John", day="Monday")
        assert str(filled) == "Hello, John! Today is Monday."

    def test_format_missing_variable(self):
        tmpl = evaluation.PromptTemplate("Hello, {name}!")
        # Assembling without values leaves the placeholder in place.
        assert str(tmpl.assemble()) == "Hello, {name}!"
        assert tmpl.variables == {"name"}

    def test_partial_format(self):
        tmpl = evaluation.PromptTemplate("Hello, {name}! Today is {day}.")
        partial = tmpl.assemble(name="John")

        # A partial fill is still a template exposing the unfilled variables.
        assert isinstance(partial, evaluation.PromptTemplate)
        assert str(partial) == "Hello, John! Today is {day}."
        assert partial.variables == {"day"}

        assert str(partial.assemble(day="Monday")) == "Hello, John! Today is Monday."

    def test_str(self):
        greeting = "Hello, world!"
        assert str(evaluation.PromptTemplate(greeting)) == greeting

    def test_repr(self):
        raw = "Hello, {name}!"
        expected = f"PromptTemplate('{raw}')"
        assert repr(evaluation.PromptTemplate(raw)) == expected

    def test_pointwise_metric_prompt_template(self):
        template = evaluation.PointwiseMetricPromptTemplate(
            criteria={"metric1": "summarization"},
            rating_rubric={"1": "good", "0": "bad"},
            input_variables=["country"],
            instruction="hello",
            metric_definition="this is eval metric",
            evaluation_steps={"step1": "start", "step2": "finish"},
            few_shot_examples=["Q: hi A: hello"],
        )
        expected = _EXPECTED_POINTWISE_PROMPT_TEMPLATE.strip()
        assert str(template) == expected

    def test_pointwise_metric_prompt_template_with_default_values(self):
        template = evaluation.PointwiseMetricPromptTemplate(
            criteria=_CRITERIA,
            rating_rubric=_POINTWISE_RATING_RUBRIC,
        )
        expected = _EXPECTED_POINTWISE_PROMPT_TEMPLATE_WITH_DEFAULT_VALUES.strip()
        assert str(template) == expected

    def test_pairtwise_metric_prompt_template(self):
        template = evaluation.PairwiseMetricPromptTemplate(
            criteria={"metric1": "summarization"},
            rating_rubric={"A": "good", "B": "good"},
            input_variables=["country"],
            instruction="hello",
            metric_definition="this is eval metric",
            evaluation_steps={"step1": "start", "step2": "finish"},
            few_shot_examples=["Q: hi A: hello"],
        )
        expected = _EXPECTED_PAIRWISE_PROMPT_TEMPLATE.strip()
        assert str(template) == expected

    def test_pairtwise_metric_prompt_template_with_default_values(self):
        template = evaluation.PairwiseMetricPromptTemplate(
            criteria=_CRITERIA,
            rating_rubric=_PAIRWISE_RATING_RUBRIC,
        )
        expected = _EXPECTED_PAIRWISE_PROMPT_TEMPLATE_WITH_DEFAULT_VALUES.strip()
        assert str(template) == expected

    def test_complex_prompt_template_variables(self):
        template_str = """Metric prompt template
instructions ...
Here are some JSON structures
{
  "Function API spec": You may use default python libraries,
  "example": test test
}
Output format prompt with JSON:
The answer should be a json alone which follows the json structure below:
{
  "is_the_response_valid": [valid or invalid],
  "reasoning":
  "rewritten response":
}
Here are some actual variables:
{var_1} {var2} {_var_3} {
        var_5_mutli_line
} {VAR_6} {7_var} {{var_9}}
"""
        # JSON braces, the multi-line `{ var_5... }` and digit-leading `{7_var}`
        # are not variables; `{{var_9}}` still yields `var_9`.
        expected_variables = {"var_1", "var2", "_var_3", "VAR_6", "var_9"}
        parsed = evaluation.PromptTemplate(template_str)
        assert parsed.variables == expected_variables


@pytest.mark.usefixtures("google_auth_mock")
class TestRubricGeneration:
    """Test class for RubricGeneration."""

    @staticmethod
    def _make_mock_model(response):
        """Returns an autospec'd GenerativeModel whose calls return `response`."""
        mock_model = mock.create_autospec(
            generative_models.GenerativeModel, instance=True
        )
        mock_model.generate_content.return_value = response
        mock_model._model_name = "publishers/google/model/gemini-1.0-pro"
        return mock_model

    @staticmethod
    def _make_rubric_metric(model, **generation_config_kwargs):
        """Builds a preview RubricBasedMetric that generates rubrics with `model`.

        Extra keyword arguments (e.g. `parsing_fn`) are forwarded to
        `RubricGenerationConfig`.
        """
        return evaluation_preview.metrics.RubricBasedMetric(
            generation_config=RubricGenerationConfig(
                prompt_template="Generate rubrics for the given prompt: {prompt}",
                model=model,
                **generation_config_kwargs,
            ),
            critique_metric=metric_prompt_template_examples_preview.MetricPromptTemplateExamples.Pointwise.COHERENCE,
        )

    def test_rubric_generation_succeeds(self):
        """`generate_rubrics` output matches the expected dataset with rubrics."""
        mock_model = self._make_mock_model(_MOCK_MODEL_RUBRIC_GENERATION_RESPONSE)
        rbm = self._make_rubric_metric(mock_model)
        dataset_with_rubrics = rbm.generate_rubrics(_TEST_EVAL_DATASET_PROMPT_RESPONSE)
        assert dataset_with_rubrics.equals(
            _EXPECTED_EVAL_DATASET_PROMPT_RESPONSE_WITH_RUBRICS
        )

    def test_rubric_generation_skipped(self):
        """The model is not called when the dataset already contains rubrics."""
        mock_model = self._make_mock_model(_MOCK_MODEL_RUBRIC_GENERATION_RESPONSE)
        rbm = self._make_rubric_metric(mock_model)
        _ = rbm.generate_rubrics(_EXPECTED_EVAL_DATASET_PROMPT_RESPONSE_WITH_RUBRICS)
        assert mock_model.generate_content.call_count == 0

    def test_rubric_generation_default_parsing_fn(self):
        """Test rubric generation using default parsing function."""
        mock_model = self._make_mock_model(_MOCK_MODEL_RUBRIC_GENERATION_RESPONSE)
        rbm = self._make_rubric_metric(mock_model)
        dataset_with_rubrics = rbm.generate_rubrics(_TEST_EVAL_DATASET_PROMPT_RESPONSE)
        assert dataset_with_rubrics.equals(
            _EXPECTED_EVAL_DATASET_PROMPT_RESPONSE_WITH_RUBRICS
        )

    def test_rubric_generation_parsing_str(self):
        """Test rubric generation with a custom `parsing_fn` taking the response str."""
        mock_model = self._make_mock_model(_MOCK_MODEL_RUBRIC_GENERATION_RESPONSE)

        def parsing_fn(response: str):
            # Ignores the model output and returns a fixed list of rubrics.
            return ["test_rubric1", "test_rubric2"]

        rbm = self._make_rubric_metric(mock_model, parsing_fn=parsing_fn)
        dataset_with_rubrics = rbm.generate_rubrics(_TEST_EVAL_DATASET_PROMPT_RESPONSE)
        assert dataset_with_rubrics.equals(
            _EXPECTED_EVAL_DATASET_PROMPT_RESPONSE_WITH_RUBRICS
        )

    def test_rubric_generation_parsing_additional_fields(self):
        """Test rubric generation using default parsing function with additional fields."""
        mock_model = self._make_mock_model(
            _MOCK_MODEL_RUBRIC_GENERATION_RESPONSE_WITH_ADDITIONAL
        )
        rbm = self._make_rubric_metric(mock_model)
        dataset_with_rubrics = rbm.generate_rubrics(_TEST_EVAL_DATASET_PROMPT_RESPONSE)
        # Copy before adding the column: the module-level expected DataFrame is
        # shared with the other tests in this class, and mutating it in place
        # would make their outcome depend on test execution order.
        expected = _EXPECTED_EVAL_DATASET_PROMPT_RESPONSE_WITH_RUBRICS.copy()
        expected["desc"] = ["test_desc", "test_desc", "test_desc"]
        assert dataset_with_rubrics.equals(expected)
